diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8669c25c452b53da48239bc20c9a2d3528e75422..db4b1581ae671b1e676e215c9a80dfaab832fa21 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -90,7 +90,7 @@ Bazel BUILD files also need to include a license section, e.g., Changes to TensorFlow C++ code should conform to [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html). -Use `clang-tidy` to check your C/C++ changes. To install clang-tidy on ubuntu:16.04, do: +Use `clang-tidy` to check your C/C++ changes. To install `clang-tidy` on ubuntu:16.04, do: ```bash apt-get install -y clang-tidy diff --git a/README.md b/README.md index e7f4080cf44f9a4f3346b33f03ac9930befa9a6a..63853137cfd30b396f8c7d204811f3e4a1794c07 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ $ python 42 >>> sess.close() ``` +Learn more examples about how to do specific tasks in TensorFlow at the [tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/). ## Contribution guidelines @@ -91,9 +92,10 @@ The TensorFlow project strives to abide by generally accepted best practices in ### Community Supported Builds -| Build Type | Status | Artifacts | -| --- | --- | --- | -| **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA | +| Build Type | Status | Artifacts | +| --- | --- | --- | +| **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA | +| **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA | ## For more information diff --git a/RELEASE.md b/RELEASE.md index 84d9d52868ecd55d38d6073315749d11c2340e8c..e09e9c6190f57adec67c2ae1d85848dabfd9c2a7 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,62 @@ +# Release 1.9.0 + +## Major Features And 
Improvements +* Update tf.keras to the Keras 2.1.6 API. +* `tfe.Network` is deprecated. Please inherit from `tf.keras.Model`. +* Adding support of core feature columns and losses to gradient boosted trees estimators. +* The distributions.Bijector API supports broadcasting for Bijectors with new API changes. See [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/distributions/bijectors/Bijector) for more details. +* Layered variable names have changed under the following conditions: + * Using `tf.keras.layers` with custom variable scopes. + * Using `tf.layers` in a subclassed `tf.keras.Model` class. See [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) for more details. + +## Breaking Changes + * If you're opening empty variable scopes, replace `variable_scope('', ...)` with `variable_scope(tf.get_variable_scope(), ...)`. + +## Bug Fixes and Other Changes +* `tf.data`: + * The `DatasetBase::DebugString()` method is now `const`. + * Added the `tf.contrib.data.sample_from_datasets()` API for randomly sampling from multiple datasets. +* Eager Execution: +* `tf.keras`: + * Move Keras code out of _impl folder and remove API files. + * `tf.keras.Model.save_weights` now saves in TensorFlow format by default. + * Enable dataset iterators to be passed to `tf.keras.Model` training/eval methods. +* Accelerated Linear Algebra (XLA): +* TensorFlow Debugger (tfdbg): fix an issue in which the TensorBoard Debugger Plugin could not handle total source file size exceeding gRPC message size limit (4 MB). +* `tf.contrib`: + * Add `tf.contrib.data.choose_from_datasets()`. + * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted strings. Two arguments were removed from `make_csv_dataset`. + * `tf.contrib.framework.zero_initializer` supports ResourceVariable. + * Adding "constrained_optimization" to tensorflow/contrib. +* Other: + * Add GCS Configuration Ops. 
+ * Changing signature of `MakeIterator` to enable propagating error status. + * KL divergence for two Dirichlet distributions. + * More consistent GcsFileSystem behavior for certain reads past EOF. + * Update benchmark for tf.scan to match ranges across eager and graph modes. + * Fixed bug in `tf.reduce_prod` gradient for complex dtypes. + * Add optional `args` argument to `Dataset.from_generator()`. + * Allow the use of '.' in variables (e.g. "hparams.parse('a.b=1.0')"), which would previously raise an error. This will correspond to an attribute name with an embedded '.' symbol (e.g. 'a.b'), which can only be accessed indirectly (e.g. through getattr and setattr). To set this up the user will first need to explicitly add the variable to the hparam object (e.g. "hparams.add_hparam(name='a.b', value=0.0)"). + * Benchmark for tf.scan in graph and eager modes. + * Added complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D. + * Making ids unique in `nn.embedding_lookup_sparse`. This helps to reduce RPC calls for looking up the embeddings when there are repeated ids in the batch. + * Support indicator column in boosted trees. + * Prevent `tf.gradients()` from backpropagating through integer tensors. + * LinearOperator[1D,2D,3D]Circulant added to `tensorflow.linalg`. + * Conv3D, Conv3DBackpropInput, Conv3DBackpropFilter now supports arbitrary. + * Added `tf.train.Checkpoint` for reading/writing object-based checkpoints. + * `Dataset.list_files()` now produces deterministic results when `shuffle=False` or a `seed` is passed. + * Added LinearOperatorKronecker, a dense-free implementation of the Kronecker Product. + * Allow LinearOperator to broadcast. + * SavedModelBuilder will now deduplicate asset names that point to files with the same basename and the same contents. Note that this may result in new asset files included in SavedModels in cases where assets with the same name but different contents were previously overwriting each other. 
+ + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +Abdullah Alrasheed, Achal Shah, Ad-530, ADiegoCAlonso, Aditya Yogi, Ag Ramesh, akindyakov, Andy Kernahan, Anya Petrova, Aurelien Geron, Ben, Ben Barsdell, Bhavani-Subramanian, braincodercn, Brett Koonce, Brian Nemsick, Brian Zier, Bryan Heden, candy.dc, cclauss, Clayne Robison, ctiijima, Dalmo Cirne, David Norman, David T.H. Kao, DosLin, ekelsen, Elson Rodriguez, Erik Smistad, Felix Abecassis, Fergal Cotter, fo40225, foo0x29a, Freedom" Koan-Sin Tan, FréDéRic Branchaud-Charron, gdh1995, Geoffrey Irving, Giuseppe, gracehoney, Guido Zuidhof, Guillaume Klein, Guozhong Zhuang, Haggai, Harald Husum, imsheridan, Ivan Zhang, Jan Zikes, Jayaram Bobba, Jesse Benson, Jesse Gumz, Jiajia Li, Jie, jinghuangintel, Jingwen, jjsjann123, Joe Yearsley, Joel Hestness, Joel Shor, josephyearsley, Junpeng Lao, Karol M. Langner, Kb Sriram, krantideep95, Krish Ravindranath, Letian Feng, Loo Rong Jie, Lukas Geiger, Maciej, Mahmoud Abuzaina, ManHyuk, Mark Ryan, mbhuiyan, Michal Turek, Mostafa Alaa, Myungsung Kwak, Nand Dalal, Nehal J Wani, Neil Tenenholtz, ngc92, Nicholas Nadeau, P.Eng., Avs, Niranjan Hasabnis, P-Hidringer, Paul Van Eck, Peng Yu, Qing Zhao, Qingying Chen, Quanlong, Rajendra Arora, Rholais Lii, rmanyari, Robin Richtsfeld, Russell Klopfer, Sagi, Sam Sendelbach, Sandeep N Gupta, Sandip Giri, Sarah Edkins, Scott Tseng, Sdalbsoo, Sergii Khomenko, Seungwoo Choi (Biggie), Seyed Majid Azimi, Shaoning Zeng, shengfuintel, Siu Kei, Muk, Smit Shilu, soonson, Stefan Schweter, Sukhwan Kim, Sunitha Kambhampati, Taehoon Lee, tamimaddari82, Tang, Wenyi, Ted Chang, u2takey, Utkarsh Upadhyay, Vadim Markovtsev, voegtlel, Wai Hon Law, wangsiyu, Wenhao Hu, wenhao.hu, William D. 
Irons, Yan Facai (颜发才), Yanbo Liang, Yihong Wang, Yilei (Dolee) Yang, Yong Tang, Yuan (Terry) Tang + # Release 1.8.0 ## Major Features And Improvements @@ -404,14 +463,6 @@ answered questions, and were part of inspiring discussions. # Release 1.4.0 -## Major Features And Improvements -* `tf.keras` is now part of the core TensorFlow API. -* [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of - the core TensorFlow API. - * The API is now subject to backwards compatibility guarantees. - -# Release 1.4.0 - ## Major Features And Improvements * `tf.keras` is now part of the core TensorFlow API. * [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of diff --git a/SECURITY.md b/SECURITY.md index 01886b613e5d93793953124331b57f075fe7a373..e2f6ff353a3c04a6ec6b8ccbaeb75db59fa22d54 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -168,7 +168,7 @@ below). Please use a descriptive subject line for your report email. After the initial reply to your report, the security team will endeavor to keep you informed of -the progress being made towards a fix and announcement. +the progress being made towards a fix and announcement. In addition, please include the following information along with your report: @@ -242,9 +242,7 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc= -----END PGP PUBLIC KEY BLOCK----- ``` -### Known vulnerabilities - -| Type | Versions affected | Reported by | Additional Information | -|--------------------|:-----------------:|-----------------------|-----------------------------| -| Out Of Bounds Read | <=1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | +### Known Vulnerabilities +For a list of known vulnerabilities and security advisories for TensorFlow, +[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md). 
diff --git a/WORKSPACE b/WORKSPACE index 4ddfb9a3832ea1ea639ace887e1d601bdd857086..fd7570a80ae2ee0087f7d2fd771fcce5b9690028 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -22,26 +22,10 @@ check_bazel_version_at_least("0.10.0") load("//tensorflow:workspace.bzl", "tf_workspace") -# Uncomment and update the paths in these entries to build the Android demo. -#android_sdk_repository( -# name = "androidsdk", -# api_level = 23, -# # Ensure that you have the build_tools_version below installed in the -# # SDK manager as it updates periodically. -# build_tools_version = "26.0.1", -# # Replace with path to Android SDK on your system -# path = "", -#) -# -#android_ndk_repository( -# name="androidndk", -# path="", -# # This needs to be 14 or higher to compile TensorFlow. -# # Please specify API level to >= 21 to build for 64-bit -# # archtectures or the Android NDK will automatically select biggest -# # API level that it supports without notice. -# # Note that the NDK version is not the API level. -# api_level=14) +load("//third_party/android:android_configure.bzl", "android_configure") +android_configure(name="local_config_android") +load("@local_config_android//:android.bzl", "android_workspace") +android_workspace() # Please add all new TensorFlow dependencies in workspace.bzl. tf_workspace() diff --git a/configure.py b/configure.py index 6d9aba61bbc73ba1b80321d6859877c371dc5427..ada342a50ab5104509156d3e44e6435a308255a3 100644 --- a/configure.py +++ b/configure.py @@ -498,10 +498,6 @@ def set_cc_opt_flags(environ_cp): if not is_ppc64le() and not is_windows(): write_to_bazelrc('build:opt --host_copt=-march=native') write_to_bazelrc('build:opt --define with_default_optimizations=true') - # TODO(mikecase): Remove these default defines once we are able to get - # TF Lite targets building without them. 
- write_to_bazelrc('build --copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') - write_to_bazelrc('build --host_copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') def set_tf_cuda_clang(environ_cp): """set TF_CUDA_CLANG action_env. @@ -674,8 +670,9 @@ def create_android_ndk_rule(environ_cp): error_msg=('The path %s or its child file "source.properties" ' 'does not exist.') ) - - write_android_ndk_workspace_rule(android_ndk_home_path) + write_action_env_to_bazelrc('ANDROID_NDK_HOME', android_ndk_home_path) + write_action_env_to_bazelrc('ANDROID_NDK_API_LEVEL', + check_ndk_level(android_ndk_home_path)) def create_android_sdk_rule(environ_cp): @@ -737,41 +734,12 @@ def create_android_sdk_rule(environ_cp): error_msg=('The selected SDK does not have build-tools version %s ' 'available.')) - write_android_sdk_workspace_rule(android_sdk_home_path, - android_build_tools_version, - android_api_level) - - -def write_android_sdk_workspace_rule(android_sdk_home_path, - android_build_tools_version, - android_api_level): - print('Writing android_sdk_workspace rule.\n') - with open(_TF_WORKSPACE, 'a') as f: - f.write(""" -android_sdk_repository( - name="androidsdk", - api_level=%s, - path="%s", - build_tools_version="%s")\n -""" % (android_api_level, android_sdk_home_path, android_build_tools_version)) - - -def write_android_ndk_workspace_rule(android_ndk_home_path): - print('Writing android_ndk_workspace rule.') - ndk_api_level = check_ndk_level(android_ndk_home_path) - if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS: - print('WARNING: The API level of the NDK in %s is %s, which is not ' - 'supported by Bazel (officially supported versions: %s). Please use ' - 'another version. 
Compiling Android targets may result in confusing ' - 'errors.\n' % (android_ndk_home_path, ndk_api_level, - _SUPPORTED_ANDROID_NDK_VERSIONS)) - with open(_TF_WORKSPACE, 'a') as f: - f.write(""" -android_ndk_repository( - name="androidndk", - path="%s", - api_level=%s)\n -""" % (android_ndk_home_path, ndk_api_level)) + write_action_env_to_bazelrc('ANDROID_BUILD_TOOLS_VERSION', + android_build_tools_version) + write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL', + android_api_level) + write_action_env_to_bazelrc('ANDROID_SDK_HOME', + android_sdk_home_path) def check_ndk_level(android_ndk_home_path): @@ -784,18 +752,16 @@ def check_ndk_level(android_ndk_home_path): revision = re.search(r'Pkg.Revision = (\d+)', filedata) if revision: - return revision.group(1) - return None - - -def workspace_has_any_android_rule(): - """Check the WORKSPACE for existing android_*_repository rules.""" - with open(_TF_WORKSPACE, 'r') as f: - workspace = f.read() - has_any_rule = re.search(r'^android_[ns]dk_repository', - workspace, - re.MULTILINE) - return has_any_rule + ndk_api_level = revision.group(1) + else: + raise Exception('Unable to parse NDK revision.') + if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS: + print('WARNING: The API level of the NDK in %s is %s, which is not ' + 'supported by Bazel (officially supported versions: %s). Please use ' + 'another version. 
Compiling Android targets may result in confusing ' + 'errors.\n' % (android_ndk_home_path, ndk_api_level, + _SUPPORTED_ANDROID_NDK_VERSIONS)) + return ndk_api_level def set_gcc_host_compiler_path(environ_cp): @@ -1227,7 +1193,7 @@ def set_tf_cuda_compute_capabilities(environ_cp): # Check whether all capabilities from the input is valid all_valid = True # Remove all whitespace characters before splitting the string - # that users may insert by accident, as this will result in error + # that users may insert by accident, as this will result in error tf_cuda_compute_capabilities = ''.join(tf_cuda_compute_capabilities.split()) for compute_capability in tf_cuda_compute_capabilities.split(','): m = re.match('[0-9]+.[0-9]+', compute_capability) @@ -1431,6 +1397,10 @@ def set_grpc_build_flags(): write_to_bazelrc('build --define grpc_no_ares=true') +def set_build_strip_flag(): + write_to_bazelrc('build --strip=always') + + def set_windows_build_flags(): if is_windows(): # The non-monolithic build is not supported yet @@ -1553,23 +1523,18 @@ def main(): set_grpc_build_flags() set_cc_opt_flags(environ_cp) + set_build_strip_flag() set_windows_build_flags() - if workspace_has_any_android_rule(): - print('The WORKSPACE file has at least one of ["android_sdk_repository", ' - '"android_ndk_repository"] already set. Will not ask to help ' - 'configure the WORKSPACE. 
Please delete the existing rules to ' - 'activate the helper.\n') - else: - if get_var( - environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', - False, - ('Would you like to interactively configure ./WORKSPACE for ' - 'Android builds?'), - 'Searching for NDK and SDK installations.', - 'Not configuring the WORKSPACE for Android builds.'): - create_android_ndk_rule(environ_cp) - create_android_sdk_rule(environ_cp) + if get_var( + environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', + False, + ('Would you like to interactively configure ./WORKSPACE for ' + 'Android builds?'), + 'Searching for NDK and SDK installations.', + 'Not configuring the WORKSPACE for Android builds.'): + create_android_ndk_rule(environ_cp) + create_android_sdk_rule(environ_cp) print('Preconfigured Bazel build configs. You can use any of the below by ' 'adding "--config=<>" to your build command. See tools/bazel.rc for ' diff --git a/tensorflow/BUILD b/tensorflow/BUILD index f2ad16fa04f5beb6616c58c28d0f0c460c3e3a17..6d134dbb80cb8c3dcf15b2ba20783870a67e9a62 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -19,6 +19,10 @@ load( "//tensorflow/core:platform/default/build_config.bzl", "tf_additional_binary_deps", ) +load( + "//tensorflow/tools/api/generator:api_gen.bzl", + "gen_api_init_files", # @unused +) # Config setting for determining if we are building for Android. config_setting( @@ -471,7 +475,7 @@ tf_cc_shared_object( # excludes all but a subset of function names. # On MacOS, the linker does not support version_script, but has an # an "-exported_symbols_list" command. -z defs disallows undefined -# symbols in object files and -s strips the output. +# symbols in object files. 
tf_cc_shared_object( name = "libtensorflow.so", @@ -485,7 +489,6 @@ tf_cc_shared_object( "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", - "-s", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file "$(location //tensorflow/c:version_script.lds)", ], @@ -511,7 +514,6 @@ tf_cc_shared_object( "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", - "-s", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file "$(location //tensorflow:tf_version_script.lds)", ], @@ -536,13 +538,19 @@ exports_files( ], ) +gen_api_init_files( + name = "tensorflow_python_api_gen", + srcs = ["api_template.__init__.py"], + root_init_template = "api_template.__init__.py", +) + py_library( name = "tensorflow_py", - srcs = ["__init__.py"], + srcs = [ + ":tensorflow_python_api_gen", + "//tensorflow/python/estimator/api:estimator_python_api_gen", + ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = [ - "//tensorflow/python", - "//tensorflow/tools/api/generator:python_api", - ], + deps = ["//tensorflow/python"], ) diff --git a/tensorflow/__init__.py b/tensorflow/__init__.py index c8683e3976c90add3f1f54d8e575c798327e9273..440e9f8dbd2f4b2a2ab78eaaf26408584e7c1446 100644 --- a/tensorflow/__init__.py +++ b/tensorflow/__init__.py @@ -22,9 +22,6 @@ from __future__ import print_function # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import -# pylint: disable=wildcard-import -from tensorflow.tools.api.generator.api import * # pylint: disable=redefined-builtin -# pylint: enable=wildcard-import from tensorflow.python.util.lazy_loader import LazyLoader contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9662d7b478ba61c69edc20b0d47293f9939e7881 
--- /dev/null +++ b/tensorflow/api_template.__init__.py @@ -0,0 +1,58 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Bring in all of the public TensorFlow interface into this module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=g-bad-import-order +from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +# API IMPORTS PLACEHOLDER + +try: + import os # pylint: disable=g-import-not-at-top + # Add `estimator` attribute to allow access to estimator APIs via + # "tf.estimator..." + from tensorflow.python.estimator.api import estimator # pylint: disable=g-import-not-at-top + + # Add `estimator` to the __path__ to allow "from tensorflow.estimator..." + # style imports. 
+ from tensorflow.python.estimator import api as estimator_api # pylint: disable=g-import-not-at-top + __path__ += [os.path.dirname(estimator_api.__file__)] + del estimator_api + del os +except (ImportError, AttributeError): + print('tf.estimator package not installed.') + +from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top +contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') +del LazyLoader + +from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top +app.flags = flags # pylint: disable=undefined-variable + +del absolute_import +del division +del print_function + +# These symbols appear because we import the python package which +# in turn imports from tensorflow.core and tensorflow.python. They +# must come from this module. So python adds these symbols for the +# resolution to succeed. +# pylint: disable=undefined-variable +del python +del core +# pylint: enable=undefined-variable diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index b86b277ac3200b88ae03490a6c1b64d464e81950..cb0b093ad260e000dcef9d1123e967a77cf1a041 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -631,7 +631,22 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, "Failed to allocate memory to serialize message of type '", in.GetTypeName(), "' and size ", proto_size); } - in.SerializeToArray(buf, proto_size); + // SerializeToArray takes size as an int. 
+ // This next 'if' is a workaround till we update to depend on a version + // of protocol buffers that includes + // https://github.com/google/protobuf/pull/4739 + if (proto_size > std::numeric_limits<int>::max()) { + return InvalidArgument("Cannot serialize protocol buffer of type ", + in.GetTypeName(), " as the serialized size (", + proto_size, + " bytes) would be larger than the limit (", + std::numeric_limits<int>::max(), " bytes)"); + } + if (!in.SerializeToArray(buf, proto_size)) { + return InvalidArgument("Unable to serialize ", in.GetTypeName(), + " protocol buffer, perhaps the serialized size (", + proto_size, " bytes) is too large?"); + } out->data = buf; out->length = proto_size; out->data_deallocator = [](void* data, size_t length) { diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index cd19cf8d624d9b914b61132f93d918b046cdbd30..c16aba666ee6974fed5351c2d9ac291dcbcdecab 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -20,6 +20,7 @@ limitations under the License. 
#include #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 14321191625e448637aa44a7f6a17820159b97c2..f265da2c2c89c0e9caf14f2213c606fcb69997e0 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -14,6 +14,7 @@ tf_cuda_library( name = "c_api", srcs = [ "c_api.cc", + "c_api_debug.cc", "c_api_internal.h", ], hdrs = ["c_api.h"], @@ -24,10 +25,10 @@ tf_cuda_library( "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ - ":runtime", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/core:core_cpu", + "//tensorflow/core/common_runtime/eager:attr_builder", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:execute", @@ -45,10 +46,22 @@ tf_cuda_library( "//tensorflow:with_xla_support": [ "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/jit", + "//tensorflow/compiler/jit:xla_device", ], "//conditions:default": [], }) + [ "//tensorflow/core/common_runtime/eager:eager_operation", + "//tensorflow/core/distributed_runtime/eager:eager_client", + "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", + "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", + "//tensorflow/core/distributed_runtime:remote_device", + "//tensorflow/core/distributed_runtime:server_lib", + 
"//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core:gpu_runtime", ], ) @@ -59,7 +72,6 @@ tf_cuda_library( visibility = ["//tensorflow:internal"], deps = [ ":c_api", - ":runtime", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/core:core_cpu", @@ -69,70 +81,65 @@ tf_cuda_library( "//tensorflow/core:framework_lite", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/common_runtime/eager:attr_builder", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/common_runtime/eager:kernel_and_device", "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/distributed_runtime:remote_device", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/distributed_runtime/eager:eager_client", + "//tensorflow/core/distributed_runtime/eager:remote_tensor_handle", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", + "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", + "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", ], ) -tf_cuda_cc_test( - name = "c_api_test", - srcs = ["c_api_test.cc"], - extra_copts = tfe_xla_copts(), - tags = [ - "guitar", - "multi_gpu", +tf_cuda_library( + name = "c_api_test_util", + testonly = 1, + srcs = ["c_api_test_util.cc"], + hdrs = ["c_api_test_util.h"], + visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", ], deps = [ ":c_api", "//tensorflow/c:c_test_util", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", - 
"//tensorflow/core:test_main", ], ) -tf_cuda_library( - name = "runtime", - srcs = ["runtime.cc"], - hdrs = ["runtime.h"], - copts = tf_copts(), - visibility = ["//tensorflow:internal"], - deps = select({ - "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", - ], - "//conditions:default": [ - "//tensorflow/c:c_api", - "//tensorflow/core:core_cpu", - "//tensorflow/core/common_runtime/eager:kernel_and_device", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - ], - }), -) - -tf_cc_test( - name = "runtime_test", - srcs = ["runtime_test.cc"], +tf_cuda_cc_test( + name = "c_api_test", + srcs = [ + "c_api_debug_test.cc", + "c_api_test.cc", + ], + extra_copts = tfe_xla_copts(), + tags = [ + "guitar", + "multi_gpu", + ], deps = [ - ":runtime", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:client_session", - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", + ":c_api", + ":c_api_test_util", + "//tensorflow/c:c_test_util", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", ], ) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 3bf071f3abaac7dfd4113964fd49cd9322913bd5..81221c4078bec9820ee187efdf0314da378be62b 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -24,7 +24,6 @@ limitations under the License. 
#include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h" -#include "tensorflow/c/eager/runtime.h" #ifdef TENSORFLOW_EAGER_USE_XLA #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #endif // TENSORFLOW_EAGER_USE_XLA @@ -32,15 +31,22 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h" #include "tensorflow/core/common_runtime/eager/execute.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" @@ -67,10 +73,121 @@ string DeviceName(const tensorflow::Device* d) { return (d == nullptr) ? 
"cpu:0" : d->name(); } -#ifdef TENSORFLOW_EAGER_USE_XLA -std::atomic_int_fast64_t func_id_generator(0); -#endif // TENSORFLOW_EAGER_USE_XLA +tensorflow::Status GetAllRemoteDevices( + const std::vector& remote_workers, + tensorflow::WorkerCacheInterface* worker_cache, + std::unique_ptr* device_mgr) { + std::vector remote_devices; + tensorflow::Status status; + // TODO(nareshmodi) do this in parallel instead of serially. + for (const string& remote_worker : remote_workers) { + tensorflow::Notification n; + tensorflow::NewRemoteDevices( + tensorflow::Env::Default(), worker_cache, remote_worker, + [&status, &n, &remote_devices]( + const tensorflow::Status& s, + std::vector* devices) { + status = s; + if (s.ok()) { + for (tensorflow::Device* d : *devices) { + remote_devices.push_back(d); + } + } + n.Notify(); + }); + n.WaitForNotification(); + } + std::unique_ptr remote_device_mgr( + new tensorflow::DeviceMgr(remote_devices)); + + TF_RETURN_IF_ERROR(status); + + *device_mgr = std::move(remote_device_mgr); + return tensorflow::Status::OK(); +} + +tensorflow::Status CreateRemoteContexts( + const std::vector& remote_workers, + tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, + tensorflow::gtl::FlatMap* remote_contexts) { + for (int i = 0; i < remote_workers.size(); i++) { + const string& remote_worker = remote_workers[i]; + + tensorflow::eager::CreateContextRequest request; + tensorflow::eager::CreateContextResponse response; + tensorflow::DeviceNameUtils::ParsedName parsed_name; + if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker, + &parsed_name)) { + return tensorflow::errors::InvalidArgument( + "Unable to parse ", remote_worker, " as a device name"); + } + request.mutable_server_def()->set_job_name(parsed_name.job); + request.mutable_server_def()->set_task_index(parsed_name.task); + request.set_async(async); + auto* eager_client = remote_eager_workers->GetClient(remote_worker); + if (eager_client == nullptr) { + return 
tensorflow::errors::Internal( + "Cannot find a client for the given target:", remote_worker); + } + tensorflow::Notification n; + tensorflow::Status status; + // TODO(nareshmodi) do this in parallel instead of serially. + eager_client->CreateContextAsync( + &request, &response, [&status, &n](const tensorflow::Status& s) { + status = s; + n.Notify(); + }); + n.WaitForNotification(); + TF_RETURN_IF_ERROR(status); + + remote_contexts->emplace(remote_worker, response.context_id()); + } + return tensorflow::Status::OK(); +} + +tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, + TFE_Context** ctx) { + string worker_name = tensorflow::strings::StrCat( + "/job:", opts->server_def.job_name(), + "/replica:0/task:", opts->server_def.task_index()); + std::unique_ptr server; + TF_RETURN_IF_ERROR( + tensorflow::eager::EagerGrpcServer::Create(opts->server_def, &server)); + + TF_RETURN_IF_ERROR(server->Start()); + + std::vector remote_workers; + server->master_env()->worker_cache->ListWorkers(&remote_workers); + remote_workers.erase( + std::remove(remote_workers.begin(), remote_workers.end(), worker_name), + remote_workers.end()); + + std::unique_ptr remote_device_mgr; + TF_RETURN_IF_ERROR(GetAllRemoteDevices( + remote_workers, server->master_env()->worker_cache, &remote_device_mgr)); + + std::shared_ptr channel_cache = + server->channel_cache(); + std::unique_ptr remote_eager_workers( + tensorflow::eager::NewGrpcEagerClientCache(channel_cache)); + // Initialize remote eager workers. 
+ tensorflow::gtl::FlatMap remote_contexts; + TF_RETURN_IF_ERROR(CreateRemoteContexts(remote_workers, + remote_eager_workers.get(), + opts->async, &remote_contexts)); + + tensorflow::RemoteRendezvous* r = + server->worker_env()->rendezvous_mgr->Find(0); + + auto* device_mgr = server->worker_env()->device_mgr; + *ctx = new TFE_Context(opts->session_options.options, opts->policy, + opts->async, device_mgr, r, std::move(server), + std::move(remote_eager_workers), + std::move(remote_device_mgr), remote_contexts); + + return tensorflow::Status::OK(); +} } // namespace extern "C" { @@ -91,6 +208,15 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( options->policy = policy; } +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef( + TFE_ContextOptions* options, const void* proto, size_t proto_len, + TF_Status* status) { + if (!options->server_def.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Invalid tensorflow.ServerDef protocol buffer"); + } +} + TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, unsigned char async, TF_Status* status) { @@ -100,17 +226,23 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { + if (!opts->server_def.job_name().empty()) { + TFE_Context* ctx = nullptr; + status->status = NewRemoteAwareTFE_Context(opts, &ctx); + return ctx; + } + std::vector devices; status->status = tensorflow::DeviceFactory::AddDevices( opts->session_options.options, "/job:localhost/replica:0/task:0", &devices); - if (!status->status.ok()) { - return nullptr; - } + if (!status->status.ok()) return nullptr; std::unique_ptr device_mgr( new tensorflow::DeviceMgr(devices)); + tensorflow::Rendezvous* r = new tensorflow::IntraProcessRendezvous(device_mgr.get()); + return new TFE_Context(opts->session_options.options, 
opts->policy, opts->async, std::move(device_mgr), r); } @@ -119,7 +251,10 @@ void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { delete ctx; } TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { TF_DeviceList* list = new TF_DeviceList; - ctx->context.device_mgr()->ListDeviceAttributes(&list->response); + ctx->context.local_device_mgr()->ListDeviceAttributes(&list->response); + if (ctx->context.remote_device_mgr()) { + ctx->context.remote_device_mgr()->ListDeviceAttributes(&list->response); + } return list; } diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index c06ce84a8c578aa60dd626c24bd58098b78ae750..1862af3ce2f505a6e83b4805417eaf335ed07bc0 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -81,6 +81,16 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*, TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); +// A tensorflow.ServerDef specifies remote workers (in addition to the current +// worker's name). Operations created on this context can then be executed on +// any of these remote workers by setting an appropriate device. +// +// If the following is set, all servers identified by the +// ServerDef must be up when the context is created. +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef( + TFE_ContextOptions* options, const void* proto, size_t proto_len, + TF_Status* status); + // Destroy an options object. TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*); @@ -181,6 +191,45 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice( TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name, TF_Status* status); +// Debugging/Profiling information for TFE_TensorHandle +// +// TFE_TensorDebugInfo contains information useful for debugging and +// profiling tensors.
+typedef struct TFE_TensorDebugInfo TFE_TensorDebugInfo; + +// Retrieves TFE_TensorDebugInfo for `handle`. +// If TFE_TensorHandleTensorDebugInfo succeeds, `status` is set to OK and caller +// is responsible for deleting returned TFE_TensorDebugInfo. +// If TFE_TensorHandleTensorDebugInfo fails, `status` is set to appropriate +// error and nullptr is returned. This function can block till the operation +// that produces `handle` has completed. +TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( + TFE_TensorHandle* handle, TF_Status* status); + +// Deletes `debug_info`. +TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo( + TFE_TensorDebugInfo* debug_info); + +// Returns the number of dimensions used to represent the tensor on its device. +// The number of dimensions used to represent the tensor on device can be +// different from the number returned by TFE_TensorHandleNumDims. +// The return value was current at the time of TFE_TensorDebugInfo creation. +TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims( + TFE_TensorDebugInfo* debug_info); + +// Returns the number of elements in dimension `dim_index`. +// Tensor representation on device can be transposed from its representation +// on host. The data contained in dimension `dim_index` on device +// can correspond to the data contained in another dimension in on-host +// representation. The dimensions are indexed using the standard TensorFlow +// major-to-minor order (slowest varying dimension first), +// not the XLA's minor-to-major order. +// On-device dimensions can be padded. TFE_TensorDebugInfoOnDeviceDim returns +// the number of elements in a dimension after padding. +// The return value was current at the time of TFE_TensorDebugInfo creation. +TF_CAPI_EXPORT extern int64_t TFE_TensorDebugInfoOnDeviceDim( + TFE_TensorDebugInfo* debug_info, int dim_index); + // Description of the TensorFlow op to execute.
// // Assumes that the provided 'ctx' outlives the returned TFE_Op, i.e., diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc new file mode 100644 index 0000000000000000000000000000000000000000..5006b76f1981d068e99a2c081115ebb3a66d8c7f --- /dev/null +++ b/tensorflow/c/eager/c_api_debug.cc @@ -0,0 +1,167 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api.h" + +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" +#ifdef TENSORFLOW_EAGER_USE_XLA +#include "tensorflow/compiler/jit/xla_device.h" +#endif // TENSORFLOW_EAGER_USE_XLA + +using tensorflow::int64; +using tensorflow::string; + +namespace { + +std::vector TensorShapeAsVector(TFE_TensorHandle* handle, + TF_Status* status) { + std::vector shape; + int rank = TFE_TensorHandleNumDims(handle, status); + if (!status->status.ok()) { + return shape; + } + shape.reserve(rank); + for (int i = 0; i < rank; ++i) { + shape.push_back(TFE_TensorHandleDim(handle, i, status)); + if (!status->status.ok()) { + return shape; + } + } + return shape; +} + +} // namespace + +extern "C" { + +TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( + TFE_TensorHandle* handle, TF_Status* status) { + const tensorflow::Tensor* tensor; + status->status = handle->handle->Tensor(&tensor); + if 
(!status->status.ok()) { + return nullptr; + } + + tensorflow::Device* device; + status->status = handle->handle->Device(&device); + if (!status->status.ok()) { + return nullptr; + } + +#ifdef TENSORFLOW_EAGER_USE_XLA + // If tensor resides on an XLA device, use XLA device's PaddedShapeFn. + tensorflow::XlaDevice* xla_device = + dynamic_cast(device); + if (xla_device != nullptr) { + tensorflow::XlaDevice::PaddedShapeFn shape_fn = + xla_device->metadata().padded_shape_fn(); + xla::Shape padded_shape; + status->status = shape_fn(*tensor, &padded_shape); + if (!status->status.ok()) { + return nullptr; + } + if (VLOG_IS_ON(3)) { + std::vector shape_to_log = TensorShapeAsVector(handle, status); + if (!status->status.ok()) { + // Ignore the status here as we are simply logging. + status->status = tensorflow::Status::OK(); + } else { + VLOG(3) << "Fully padded shape of [" + << tensorflow::str_util::Join(shape_to_log, ", ") << "] is " + << padded_shape.DebugString(); + } + } + + if (xla::ShapeUtil::IsTuple(padded_shape)) { + if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) { + // Currently, the only case of XlaTensor containing a tuple shape is to + // represent 64 bit ints, doubles, and complex numbers (we don't support + // 64bit complex numbers). + status->status = tensorflow::errors::InvalidArgument( + "XlaTensors should only contain tuples of size 2. Shape: ", + padded_shape.DebugString()); + return nullptr; + } + + // shape0 is not a const& because we will assign it to padded_shape below. + // It is illegal to assign a part of a message to itself. + xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0); + const xla::Shape& shape1 = + xla::ShapeUtil::GetTupleElementShape(padded_shape, 1); + if (xla::ShapeUtil::IsTuple(shape0) || xla::ShapeUtil::IsTuple(shape1)) { + status->status = tensorflow::errors::InvalidArgument( + "XlaTensors should not contain nested tuples. 
Shape: ", + padded_shape.DebugString()); + return nullptr; + } + if (!xla::ShapeUtil::Equal(shape0, shape1)) { + status->status = tensorflow::errors::InvalidArgument( + "Subshapes of XlaTensors should be the same. Shape: ", + padded_shape.DebugString()); + return nullptr; + } + + // Since the only case we handle here are two equal subshapes, we + // simply return one of them. The caller will interpret it as this + // shape directly storing the 64bit types. This approximation is good + // enough for this API's debugging use case. + padded_shape = shape0; + } + + int rank = padded_shape.dimensions_size(); + std::vector dev_dims; + dev_dims.reserve(rank); + if (rank == 1) { + // Rank 1 tensors might not have padded_shape.layout.minor_to_major set, + dev_dims.push_back(padded_shape.dimensions(0)); + } else { + for (int i = rank - 1; i >= 0; --i) { + int64 dim_index = padded_shape.layout().minor_to_major(i); + dev_dims.push_back(padded_shape.dimensions(dim_index)); + } + } + status->status = tensorflow::Status::OK(); + return new TFE_TensorDebugInfo(dev_dims); + } +#endif // TENSORFLOW_EAGER_USE_XLA + + // If the tensor is not an XLA tensor, the device shape is + // the same as regular tensor shape. 
+ std::vector dev_dims = TensorShapeAsVector(handle, status); + if (!status->status.ok()) { + return nullptr; + } + return new TFE_TensorDebugInfo(dev_dims); +} + +TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo( + TFE_TensorDebugInfo* debug_info) { + delete debug_info; +} + +TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims( + TFE_TensorDebugInfo* debug_info) { + return debug_info->dev_dims.size(); +} + +TF_CAPI_EXPORT extern int64_t TFE_TensorDebugInfoOnDeviceDim( + TFE_TensorDebugInfo* debug_info, int dim_index) { + return debug_info->dev_dims[dim_index]; +} + +} // extern "C" diff --git a/tensorflow/c/eager/c_api_debug_test.cc b/tensorflow/c/eager/c_api_debug_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cddb9f6e00e9d639026f4bbe061d58f76771c0a9 --- /dev/null +++ b/tensorflow/c/eager/c_api_debug_test.cc @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/eager/c_api.h" + +#include +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +TEST(CApiDebug, ScalarCPU) { + TFE_TensorHandle* h = TestScalarTensorHandle(); + TF_Status* status = TF_NewStatus(); + TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + ASSERT_EQ(0, TFE_TensorDebugInfoOnDeviceNumDims(debug_info)); + + TFE_DeleteTensorDebugInfo(debug_info); + TFE_DeleteTensorHandle(h); + TF_DeleteStatus(status); +} + +TEST(CApiDebug, 2DCPU) { + TFE_TensorHandle* h = TestMatrixTensorHandle3X2(); + TF_Status* status = TF_NewStatus(); + TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + ASSERT_EQ(2, TFE_TensorDebugInfoOnDeviceNumDims(debug_info)); + // Shape is the same for CPU tensors. + EXPECT_EQ(3, TFE_TensorDebugInfoOnDeviceDim(debug_info, 0)); + EXPECT_EQ(2, TFE_TensorDebugInfoOnDeviceDim(debug_info, 1)); + + TFE_DeleteTensorDebugInfo(debug_info); + TFE_DeleteTensorHandle(h); + TF_DeleteStatus(status); +} diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 49e1aab1cef9577256d9b081858cf094c788caf8..04a6efc47c5177c82b7e88168b67cc584587de7c 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -28,8 +28,8 @@ limitations under the License. 
#include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" -#include "tensorflow/c/eager/runtime.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" @@ -37,6 +37,14 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/eager/eager_client.h" +#include "tensorflow/core/distributed_runtime/remote_device.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" +#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -51,6 +59,7 @@ struct TFE_ContextOptions { // true if async execution is enabled. 
bool async = false; TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT}; + tensorflow::ServerDef server_def; }; struct TFE_Context { @@ -64,6 +73,23 @@ struct TFE_Context { default_policy), async, std::move(device_mgr), rendezvous) {} + explicit TFE_Context( + const tensorflow::SessionOptions& opts, + TFE_ContextDevicePlacementPolicy default_policy, bool async, + tensorflow::DeviceMgr* local_device_mgr, + tensorflow::Rendezvous* rendezvous, + std::unique_ptr server, + std::unique_ptr remote_eager_workers, + std::unique_ptr remote_device_mgr, + const tensorflow::gtl::FlatMap& + remote_contexts) + : context(opts, + static_cast( + default_policy), + async, local_device_mgr, rendezvous, std::move(server), + std::move(remote_eager_workers), std::move(remote_device_mgr), + remote_contexts) {} + tensorflow::EagerContext context; }; @@ -81,6 +107,14 @@ struct TFE_TensorHandle { tensorflow::TensorHandle* handle; }; +struct TFE_TensorDebugInfo { + TFE_TensorDebugInfo(const std::vector& dims) + : dev_dims(dims) {} + + // Fully-padded, minor-to-major. + std::vector dev_dims; +}; + struct TFE_Op { // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a // primitive operation. diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 701175e4943d1d23532fe595319f67711316ed4d..992d1afd5fcb0641794bb2abbe5ab20a287d3b62 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -23,128 +25,14 @@ limitations under the License. 
#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" using tensorflow::string; namespace { -TFE_TensorHandle* DoubleTestMatrixTensorHandle() { - int64_t dims[] = {2, 2}; - double data[] = {1.0, 2.0, 3.0, 4.0}; - TF_Tensor* t = TF_AllocateTensor( - TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteTensor(t); - TF_DeleteStatus(status); - return th; -} - -TFE_TensorHandle* TestMatrixTensorHandle() { - int64_t dims[] = {2, 2}; - float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; - TF_Tensor* t = TF_AllocateTensor( - TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteTensor(t); - TF_DeleteStatus(status); - return th; -} - -TFE_TensorHandle* TestMatrixTensorHandle3X2() { - int64_t dims[] = {3, 2}; - double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; - TF_Tensor* t = TF_AllocateTensor( - TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteTensor(t); - TF_DeleteStatus(status); - return th; -} - -TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { - TF_Status* status = TF_NewStatus(); - - TFE_Op* op = 
TFE_NewOp(ctx, "MatMul", status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(op, a, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(op, b, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteStatus(status); - TFE_OpSetAttrBool(op, "transpose_a", 0); - TFE_OpSetAttrBool(op, "transpose_b", 0); - TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); - - return op; -} - -TFE_TensorHandle* TestAxisTensorHandle() { - int64_t dims[] = {1}; - int data[] = {1}; - TF_Tensor* t = TF_AllocateTensor( - TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteTensor(t); - TF_DeleteStatus(status); - return th; -} - -TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, - TFE_TensorHandle* axis) { - TF_Status* status = TF_NewStatus(); - - TFE_Op* op = TFE_NewOp(ctx, "Min", status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(op, input, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(op, axis, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpSetAttrBool(op, "keep_dims", 1); - TFE_OpSetAttrType(op, "Tidx", TF_INT32); - TF_DeleteStatus(status); - TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input)); - - return op; -} - -// If there is a GPU device, returns true and sets 'gpu_device_name' -// accordingly. 
-bool GetGPUDeviceName(TFE_Context* ctx, string* gpu_device_name) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); - CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - - const int num_devices = TF_DeviceListCount(devices); - for (int i = 0; i < num_devices; ++i) { - const string device_type(TF_DeviceListType(devices, i, status.get())); - CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); - const string device_name(TF_DeviceListName(devices, i, status.get())); - CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); - if (device_type == "GPU") { - *gpu_device_name = device_name; - LOG(INFO) << "Found GPU device " << device_name; - TF_DeleteDeviceList(devices); - return true; - } - } - TF_DeleteDeviceList(devices); - return false; -} - void BM_InitOp(int iters) { tensorflow::testing::StopTiming(); TF_Status* status = TF_NewStatus(); @@ -220,6 +108,182 @@ TEST(CAPI, Context) { TF_DeleteStatus(status); } +tensorflow::ServerDef GetServerDef(int num_tasks) { + tensorflow::ServerDef server_def; + server_def.set_protocol("grpc"); + server_def.set_job_name("localhost"); + server_def.set_task_index(0); + tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); + tensorflow::JobDef* job_def = cluster_def->add_job(); + job_def->set_name("localhost"); + for (int i = 0; i < num_tasks; i++) { + int port = tensorflow::testing::PickUnusedPortOrDie(); + job_def->mutable_tasks()->insert( + {i, tensorflow::strings::StrCat("localhost:", port)}); + } + return server_def; +} + +void TestRemoteExecute(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. 
+ string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr worker_server; + ASSERT_TRUE( + tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), + status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_ContextOptionsSetAsync(opts, static_cast(1)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, + TFE_DEVICE_PLACEMENT_EXPLICIT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + auto* h0_task1 = + TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + auto* h1_task1 = + TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1); + TFE_OpSetDevice(matmul, remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + auto* retval_task0 = TFE_TensorHandleCopyToDevice( + retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + 
TFE_DeleteTensorHandle(retval_task0); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + + TFE_DeleteTensorHandle(h0_task0); + TFE_DeleteTensorHandle(h1_task0); + TFE_DeleteTensorHandle(h0_task1); + TFE_DeleteTensorHandle(h1_task1); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(matmul); + + TFE_ContextAsyncWait(ctx, status); + TFE_DeleteContext(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_DeleteStatus(status); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server.release(); +} + +TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); } +TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); } + +void TestRemoteExecuteSilentCopies(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. 
+ string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr worker_server; + ASSERT_TRUE( + tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), + status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_ContextOptionsSetAsync(opts, static_cast(1)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + + // Handles are on task0, but op is on remote (task1). 
+ TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task0); + TFE_OpSetDevice(matmul, remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + auto* retval_task0 = TFE_TensorHandleCopyToDevice( + retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(retval_task0); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + + TFE_DeleteTensorHandle(h0_task0); + TFE_DeleteTensorHandle(h1_task0); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(matmul); + + TFE_ContextAsyncWait(ctx, status); + TFE_DeleteContext(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_DeleteStatus(status); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server.release(); +} + +TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); } +TEST(CAPI, RemoteExecuteSilentCopiesAsync) { + TestRemoteExecuteSilentCopies(true); +} + TEST(CAPI, TensorHandle) { TFE_TensorHandle* h = TestMatrixTensorHandle(); EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); @@ -436,7 +500,7 @@ void TensorHandleSilentCopy(bool async) { // Disable the test if no GPU is present. 
string gpu_device_name; - if (GetGPUDeviceName(ctx, &gpu_device_name)) { + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( hcpu, ctx, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); @@ -483,7 +547,7 @@ void TensorHandleSilentCopyLocal(bool async) { // Disable the test if no GPU is present. string gpu_device_name; - if (GetGPUDeviceName(ctx, &gpu_device_name)) { + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( hcpu, ctx, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); @@ -524,7 +588,7 @@ void SetAndGetOpDevices(bool async) { // Disable the test if no GPU is present. string gpu_device_name; - if (GetGPUDeviceName(ctx, &gpu_device_name)) { + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { TFE_OpSetDevice(matmul, "GPU:0", status); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); const char* device_name = TFE_OpGetDevice(matmul, status); @@ -588,7 +652,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) { TFE_DeleteContextOptions(opts); TFE_TensorHandle* m1 = TestMatrixTensorHandle(); - TFE_TensorHandle* m2 = TestMatrixTensorHandle3X2(); + TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle3X2(); TFE_Op* matmul = MatMulOp(ctx, m1, m2); TFE_OpSetDevice(matmul, "/job:localhost/replica:0/task:0/device:CPU:0", status); diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..5607c9dcb0bbec72b2f86def3dd4e6590d73197b --- /dev/null +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -0,0 +1,163 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api_test_util.h" + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +using tensorflow::string; + +TFE_TensorHandle* TestScalarTensorHandle() { + float data[] = {1.0f}; + TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(float)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* DoubleTestMatrixTensorHandle() { + int64_t dims[] = {2, 2}; + double data[] = {1.0, 2.0, 3.0, 4.0}; + TF_Tensor* t = TF_AllocateTensor( + TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* TestMatrixTensorHandle() { + int64_t dims[] = {2, 2}; + float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, 
status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() { + int64_t dims[] = {3, 2}; + double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* TestMatrixTensorHandle3X2() { + int64_t dims[] = {3, 2}; + float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { + TF_Status* status = TF_NewStatus(); + + TFE_Op* op = TFE_NewOp(ctx, "MatMul", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, a, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, b, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + TFE_OpSetAttrBool(op, "transpose_a", 0); + TFE_OpSetAttrBool(op, "transpose_b", 0); + TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); + + return op; +} + +TFE_TensorHandle* TestAxisTensorHandle() { + int64_t dims[] = {1}; + int data[] = {1}; + TF_Tensor* t = TF_AllocateTensor( + TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + 
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, + TFE_TensorHandle* axis) { + TF_Status* status = TF_NewStatus(); + + TFE_Op* op = TFE_NewOp(ctx, "Min", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, input, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, axis, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetAttrBool(op, "keep_dims", 1); + TFE_OpSetAttrType(op, "Tidx", TF_INT32); + TF_DeleteStatus(status); + TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input)); + + return op; +} + +bool GetDeviceName(TFE_Context* ctx, string* device_name, + const char* device_type) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); + CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + const int num_devices = TF_DeviceListCount(devices); + for (int i = 0; i < num_devices; ++i) { + const string dev_type(TF_DeviceListType(devices, i, status.get())); + CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + const string dev_name(TF_DeviceListName(devices, i, status.get())); + CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + if (dev_type == device_type) { + *device_name = dev_name; + LOG(INFO) << "Found " << device_type << " device " << *device_name; + TF_DeleteDeviceList(devices); + return true; + } + } + TF_DeleteDeviceList(devices); + return false; +} diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h new file mode 100644 index 
0000000000000000000000000000000000000000..474cae67c89249af3a62707f0db00ba458ca8f31 --- /dev/null +++ b/tensorflow/c/eager/c_api_test_util.h @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ +#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ + +#include "tensorflow/c/eager/c_api.h" + +#include "tensorflow/core/platform/types.h" + +// Return a tensor handle containing a float scalar +TFE_TensorHandle* TestScalarTensorHandle(); + +// Return a tensor handle containing a 2x2 matrix of doubles +TFE_TensorHandle* DoubleTestMatrixTensorHandle(); + +// Return a tensor handle containing a 2x2 matrix of floats +TFE_TensorHandle* TestMatrixTensorHandle(); + +// Return a tensor handle containing a 3x2 matrix of doubles +TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(); + +// Return a tensor handle containing a 3x2 matrix of floats +TFE_TensorHandle* TestMatrixTensorHandle3X2(); + +// Return a matmul op multiplying `a` by `b`. +TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b); + +// Return a 1-D INT32 tensor containing a single value 1. +TFE_TensorHandle* TestAxisTensorHandle(); + +// Return an op taking the minimum of `input` along the `axis` dimension.
+TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, + TFE_TensorHandle* axis); + +// If there is a device of type `device_type`, returns true +// and sets 'device_name' accordingly. +// `device_type` must be either "GPU" or "TPU". +bool GetDeviceName(TFE_Context* ctx, tensorflow::string* device_name, + const char* device_type); + +#endif // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index e9ed3395c448305bcd6317b0b292b4e4e0b659b1..734e712daa39c03f0177eb199b1acb1b19e5d845 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -48,7 +48,7 @@ struct OpTapeEntry { // Should be called before deleting the backward function. TODO(apassos) use // unique_ptrs to ensure this happens. - std::function backward_function_deleter; + std::function backward_function_deleter; }; // Map from tensor_id to internally-defined operation-id of the operation which @@ -104,14 +104,12 @@ class VSpace { gtl::ArraySlice output_gradients, std::vector* result) const = 0; + // Marks the following gradient as a result so it's not consumed by backward + // functions. + virtual void MarkAsResult(Gradient* gradient) const = 0; + // Deletes the input tensor. virtual void DeleteGradient(Gradient* gradient) const = 0; - - // Lets this VSpace know that it can release resources held by the - // `backward_function`, It will not be called again. - // `backward_function` must not be null. 
- virtual void ReleaseBackwardFunction( - BackwardFunction* backward_function) const = 0; }; // Traces the execution of operations, doing eager garbage collection, and @@ -126,7 +124,7 @@ class GradientTape { GradientTape(bool persistent) : persistent_(persistent) {} ~GradientTape() { for (const auto& pair : op_tape_) { - pair.second.backward_function_deleter(); + pair.second.backward_function_deleter(pair.second.backward_function); } } @@ -135,12 +133,12 @@ class GradientTape { void Watch(int64 tensor_id); - void RecordOperation(const string& op_type, - gtl::ArraySlice output_tensors, - gtl::ArraySlice input_tensor_id, - gtl::ArraySlice input_dtypes, - BackwardFunction* backward_function, - const std::function& backward_function_deleter); + void RecordOperation( + const string& op_type, gtl::ArraySlice output_tensors, + gtl::ArraySlice input_tensor_id, + gtl::ArraySlice input_dtypes, + BackwardFunction* backward_function, + const std::function& backward_function_deleter); void DeleteTrace(int64 tensor_id); @@ -195,7 +193,9 @@ bool GradientTape::ShouldRecord( CHECK_EQ(tensor_ids.size(), dtypes.size()); for (int i = 0; i < tensor_ids.size(); ++i) { if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) { - return IsDtypeTrainable(dtypes[i]); + if (IsDtypeTrainable(dtypes[i])) { + return true; + } } } return false; @@ -212,9 +212,9 @@ void GradientTape::RecordOperation( gtl::ArraySlice input_tensor_id, gtl::ArraySlice input_dtypes, BackwardFunction* backward_function, - const std::function& backward_function_deleter) { + const std::function& backward_function_deleter) { if (!ShouldRecord(input_tensor_id, input_dtypes)) { - backward_function_deleter(); + backward_function_deleter(backward_function); return; } std::vector ids; @@ -269,7 +269,7 @@ void GradientTape::DeleteTrace(int64 tensor_id) { for (int64 id : op_it->second.input_tensor_id) { DeleteTrace(id); } - op_it->second.backward_function_deleter(); + 
op_it->second.backward_function_deleter(op_it->second.backward_function); op_tape_.erase(op_it); } @@ -354,8 +354,7 @@ BackpropInitialState PrepareBackprop( count_it->second++; } else { result.tensor_usage_counts[it] = 1; - if (sources_set.find(it) == sources_set.end() && - tensor_tape.find(it) != tensor_tape.end()) { + if (tensor_tape.find(it) != tensor_tape.end()) { tensor_stack.push_back(it); } } @@ -376,7 +375,8 @@ BackpropInitialState PrepareBackprop( // backward functions that will be used for gradient computation // has been transferred to `result`. for (const auto& op_pair : *op_tape) { - op_pair.second.backward_function_deleter(); + op_pair.second.backward_function_deleter( + op_pair.second.backward_function); } op_tape->clear(); } @@ -468,7 +468,7 @@ Status GradientTape::ComputeGradient( if (!persistent_) { // Release all backprop functions for (const auto& pair : state.op_tape) { - pair.second.backward_function_deleter(); + pair.second.backward_function_deleter(pair.second.backward_function); } } }; @@ -520,10 +520,15 @@ Status GradientTape::ComputeGradient( } } else { any_gradient_nonzero = true; - out_gradients.push_back(vspace.AggregateGradients(grad_it->second)); + auto new_gradients = vspace.AggregateGradients(grad_it->second); if (sources_set.find(grad_it->first) == sources_set.end()) { gradients.erase(grad_it); + } else { + grad_it->second.clear(); + grad_it->second.push_back(new_gradients); + vspace.MarkAsResult(new_gradients); } + out_gradients.push_back(new_gradients); } } std::vector in_gradients; @@ -531,7 +536,7 @@ Status GradientTape::ComputeGradient( Status s = vspace.CallBackwardFunction(trace.backward_function, out_gradients, &in_gradients); if (!persistent_) { - vspace.ReleaseBackwardFunction(trace.backward_function); + trace.backward_function_deleter(trace.backward_function); } if (!s.ok()) { cleanup(); @@ -540,7 +545,7 @@ Status GradientTape::ComputeGradient( } else { in_gradients.resize(trace.input_tensor_id.size()); if 
(!persistent_) { - vspace.ReleaseBackwardFunction(trace.backward_function); + trace.backward_function_deleter(trace.backward_function); } for (Gradient* grad : out_gradients) { if (grad != nullptr) { diff --git a/tensorflow/c/generate-pc.sh b/tensorflow/c/generate-pc.sh index 02a6a58b6153bb78c684f9290ef95900f96e9357..7184ad68fb79f2598067d68d5ab5ba8f2c7a22c8 100755 --- a/tensorflow/c/generate-pc.sh +++ b/tensorflow/c/generate-pc.sh @@ -15,10 +15,12 @@ # ============================================================================== TF_PREFIX='/usr/local' +LIBDIR='lib' usage() { echo "Usage: $0 OPTIONS" echo -e "-p, --prefix\tset installation prefix (default: /usr/local)" + echo -e "-l, --libdir\tset lib directory (default: lib)" echo -e "-v, --version\tset TensorFlow version" echo -e "-h, --help\tdisplay this message" } @@ -26,7 +28,7 @@ usage() { [ $# == 0 ] && usage && exit 0 # read the options -ARGS=$(getopt -o p:v:h --long prefix:,version:,help -n $0 -- "$@") +ARGS=$(getopt -o p:l:v:h --long prefix:,libdir:,version:,help -n $0 -- "$@") eval set -- "$ARGS" # extract options and their arguments into variables. 
@@ -38,6 +40,11 @@ while true ; do "") shift 2 ;; *) TF_PREFIX=$2 ; shift 2 ;; esac ;; + -l|--libdir) + case "$2" in + "") shift 2 ;; + *) LIBDIR=$2 ; shift 2 ;; + esac ;; -v|--version) case "$2" in "") shift 2 ;; @@ -55,7 +62,7 @@ echo "Generating pkgconfig file for TensorFlow $TF_VERSION in $TF_PREFIX" cat << EOF > tensorflow.pc prefix=${TF_PREFIX} exec_prefix=\${prefix} -libdir=\${exec_prefix}/lib +libdir=\${exec_prefix}/${LIBDIR} includedir=\${prefix}/include Name: TensorFlow diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index d6a4f141b6bb8ccadb77f1fa83b5fb742d78f70f..dfdef88945deca376368edd6f7aa322b1e1cbf94 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -273,6 +273,12 @@ string PrintAttrValue(const string& op, const AttrValue& attr_value) { return ""; // Prevent missing return warning } +bool IsEmptyList(const AttrValue::ListValue& list) { + return list.s_size() == 0 && list.i_size() == 0 && list.f_size() == 0 && + list.b_size() == 0 && list.type_size() == 0 && + list.shape_size() == 0 && list.tensor_size() == 0; +} + string ToCamelCase(const string& str) { string result; const char joiner = '_'; @@ -297,9 +303,9 @@ string ToCamelCase(const string& str) { // indicate whether to treat the type as const when accepting the C++ type as an // argument to a function. 
std::pair AttrTypeName(StringPiece attr_type) { - static const std::unordered_map, - StringPieceHasher> - attr_type_map{ + static const auto* attr_type_map = + new std::unordered_map, + StringPieceHasher>{ {"string", {"StringPiece", false}}, {"list(string)", {"gtl::ArraySlice", true}}, {"int", {"int64", false}}, @@ -317,14 +323,34 @@ std::pair AttrTypeName(StringPiece attr_type) { {"func", {"NameAttrList", true}}, }; - auto entry = attr_type_map.find(attr_type); - if (entry == attr_type_map.end()) { + auto entry = attr_type_map->find(attr_type); + if (entry == attr_type_map->end()) { LOG(FATAL) << "Unsupported Attr type: " << attr_type; return {"", false}; } return entry->second; } +const char* ListElementTypeName(StringPiece attr_type) { + static const auto* attr_list_type_map = + new std::unordered_map{ + {"list(string)", "string"}, + {"list(int)", "int"}, + {"list(float)", "float"}, + {"list(bool)", "bool"}, + {"list(type)", "DataType"}, + {"list(shape)", "PartialTensorShape"}, + {"list(tensor)", "TensorProto"}, + }; + + auto entry = attr_list_type_map->find(attr_type); + if (entry == attr_list_type_map->end()) { + LOG(FATAL) << "Unsupported or non-list Attr type: " << attr_type; + return ""; + } + return entry->second; +} + bool IsCPPKeyword(StringPiece name) { static const std::unordered_set // Keywords obtained from http://en.cppreference.com/w/cpp/keyword @@ -668,6 +694,7 @@ OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def, string OpInfo::GetOpAttrStruct() const { string struct_fields; string setters; + string defaults_static_storage; for (int i = 0; i < graph_op_def.attr_size(); ++i) { const auto& attr(graph_op_def.attr(i)); @@ -705,11 +732,32 @@ string OpInfo::GetOpAttrStruct() const { "_ = x;\n"); strings::StrAppend(&setters, " return ret;\n }\n\n"); - strings::StrAppend( - &struct_fields, " ", attr_type_name, " ", api_def_attr.rename_to(), - "_ = ", - PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()), - ";\n"); + string 
field_initiliazer; + auto& default_value = api_def_attr.default_value(); + if (default_value.value_case() == AttrValue::kList && + !IsEmptyList(default_value.list())) { + // Non-empty lists need static storage for their defaults. Define a + // function with static local variable that stores the array. + strings::StrAppend(&defaults_static_storage, " static ", + attr_type_name, " Default_", api_def_attr.rename_to(), + "() {\n"); + strings::StrAppend( + &defaults_static_storage, " static const ", + ListElementTypeName(attr.type()), " kStorage[] = ", + PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()), + ";\n"); + strings::StrAppend(&defaults_static_storage, " return ", + attr_type_name, "(kStorage);\n }\n"); + // Set the field_initializer to call the defined function. + strings::StrAppend(&field_initiliazer, "Default_", + api_def_attr.rename_to(), "()"); + } else { + field_initiliazer = + PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()); + } + strings::StrAppend(&struct_fields, " ", attr_type_name, " ", + api_def_attr.rename_to(), "_ = ", field_initiliazer, + ";\n"); } if (struct_fields.empty()) { @@ -721,6 +769,9 @@ string OpInfo::GetOpAttrStruct() const { string struct_decl = MakeComment(attrs_comment, " "); strings::StrAppend(&struct_decl, " struct Attrs {\n"); strings::StrAppend(&struct_decl, setters, struct_fields); + if (!defaults_static_storage.empty()) { + strings::StrAppend(&struct_decl, " private:\n", defaults_static_storage); + } strings::StrAppend(&struct_decl, " };\n"); return struct_decl; diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index 52c177212a8c88f1857defcc38de4a01ac47dab0..35a01e0341cb08c9b314908b6dcd76fd99c1e68b 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -38,6 +38,7 @@ REGISTER_NO_GRADIENT_OP("NotEqual"); REGISTER_NO_GRADIENT_OP("LogicalAnd"); REGISTER_NO_GRADIENT_OP("LogicalOr"); 
REGISTER_NO_GRADIENT_OP("LogicalNot"); +REGISTER_NO_GRADIENT_OP("Floor"); // Conjugate helper function returns the conjugate of an Output if it // is complex valued. diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 1b4c7c2688083e74433da3dce2849b8c37443684..fd7b6fe6625f27bda92e2f56f60908658cdecd7e 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -31,7 +31,6 @@ using ops::AddN; using ops::BatchMatMul; using ops::Const; using ops::Div; -using ops::Greater; using ops::MatMul; using ops::Max; using ops::Maximum; @@ -46,7 +45,6 @@ using ops::RealDiv; using ops::SquaredDifference; using ops::Sub; using ops::Sum; -using ops::Where3; // TODO(andydavis) Test gradient function against numeric gradients output. // TODO(andydavis) As more gradients are added move common test functions diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 0cb3132e94e381f672d69aefe4a199d2b590830c..c73482d5f4d13ade0dc0412941251d1651371b6e 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -255,6 +255,53 @@ Status LRNGradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("LRN", LRNGradHelper); +Status SoftplusGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto dx = internal::SoftplusGrad(scope, grad_inputs[0], op.input(0)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Softplus", SoftplusGradHelper); + +Status SoftsignGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto dx = internal::SoftsignGrad(scope, grad_inputs[0], op.input(0)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Softsign", SoftsignGradHelper); + +Status FractionalAvgPoolGradHelper(const Scope& scope, const 
Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool overlapping; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), "overlapping", &overlapping)); + auto dx = internal::FractionalAvgPoolGrad( + scope, Shape(scope, op.input(0), Shape::OutType(DT_INT64)), + grad_inputs[0], op.output(1), op.output(2), + internal::FractionalAvgPoolGrad::Overlapping(overlapping)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("FractionalAvgPool", FractionalAvgPoolGradHelper); + +Status FractionalMaxPoolGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool overlapping; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), "overlapping", &overlapping)); + auto dx = internal::FractionalMaxPoolGrad( + scope, op.input(0), op.output(0), grad_inputs[0], op.output(1), + op.output(2), internal::FractionalMaxPoolGrad::Overlapping(overlapping)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("FractionalMaxPool", FractionalMaxPoolGradHelper); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index c4eba7ecb017fe4628140d75a63bc7f0f09deb7f..b4d457a9d14eb79232cda9412fa0050f6a9968cc 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -28,6 +28,8 @@ namespace { using ops::BiasAdd; using ops::Conv2D; using ops::Elu; +using ops::FractionalAvgPool; +using ops::FractionalMaxPool; using ops::L2Loss; using ops::LogSoftmax; using ops::LRN; @@ -41,6 +43,8 @@ using ops::Relu; using ops::Relu6; using ops::Selu; using ops::Softmax; +using ops::Softplus; +using ops::Softsign; class NNGradTest : public ::testing::Test { protected: @@ -71,22 +75,30 @@ class NNGradTest : public ::testing::Test { EXPECT_LT(max_error, 1e-3); } - // Sets tensor with random 
values, ensuring that the max value is largest by - // a reasonable amount. - // This is an issue for MaxPool, MaxPoolV2 and MaxPool3D, in which - // perturbations by the numeric gradient computation in the gradient checker - // can change the max value if values are too close together. + // Sets tensor with random values, ensuring that every pair of elements is at + // least a reasonable amount apart. + // This is an issue for max pooling operations, in which perturbations by the + // numeric gradient computation in the gradient checker can change the max + // value if a pool has values that are too close together. template <typename T> - void SetRandomValuesWithBumpedMax(Tensor* tensor) { + void SetRandomValuesForMaxPooling(Tensor* tensor) { auto tensor_flat = tensor->flat<T>(); - tensor_flat.setRandom(); - int32 max_index = 0; - for (size_t i = 1; i < tensor->NumElements(); i++) { - if (tensor_flat(i) > tensor_flat(max_index)) { - max_index = i; - } + // First set the array to an increasing sequence of values spaced + // a reasonable amount apart + T cur = 0; + for (size_t i = 0; i < tensor->NumElements(); i++) { + tensor_flat(i) = cur; + cur += 5e-2; + } + // Fisher-Yates shuffle the array + for (size_t i = tensor->NumElements() - 1; i >= 1; i--) { + // j <- random integer 0 <= j <= i + size_t j = random::New64() % (i + 1); + // swap values at i, j + T tmp = tensor_flat(i); + tensor_flat(i) = tensor_flat(j); + tensor_flat(j) = tmp; } - tensor_flat(max_index) += 1e-2; } Scope scope_; @@ -189,7 +201,7 @@ TEST_F(NNGradTest, MaxPoolGradHelper) { const std::vector<int> strides{1, 2, 2, 1}; auto y = MaxPool(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax<float>(&x_init_value); + SetRandomValuesForMaxPooling<float>(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -202,7 +214,7 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) { Tensor strides = test::AsTensor<int>({1, 2, 2, 1}, {4}); auto y = MaxPoolV2(scope_, x, ksize, strides,
"VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -215,7 +227,7 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) { const std::vector strides{1, 3, 3, 3, 1}; auto y = MaxPool3D(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -248,5 +260,45 @@ TEST_F(NNGradTest, LRN){ RunTest(x, x_shape, y, x_shape); } +TEST_F(NNGradTest, SoftplusGrad) { + TensorShape shape({3, 7}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Softplus(scope_, x); + RunTest(x, shape, y, shape); +} + +TEST_F(NNGradTest, SoftsignGrad) { + TensorShape shape({3, 7}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Softsign(scope_, x); + RunTest(x, shape, y, shape); +} + +TEST_F(NNGradTest, FractionalAvgPoolGradHelper) { + TensorShape x_shape({1, 3, 7, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Force consistent pooling regions for unit testing. + auto y = FractionalAvgPool( + scope_, x, {1, 1.2, 1.9, 1}, + FractionalAvgPool::Deterministic(true).Overlapping(true).Seed(1).Seed2( + 2)); + TensorShape y_shape({1, 2, 3, 1}); + RunTest(x, x_shape, y.output, y_shape); +} + +TEST_F(NNGradTest, FractionalMaxPoolGradHelper) { + TensorShape x_shape({1, 3, 7, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Force consistent pooling regions for unit testing. 
+ auto y = FractionalMaxPool( + scope_, x, {1, 1.2, 1.9, 1}, + FractionalMaxPool::Deterministic(true).Overlapping(true).Seed(1).Seed2( + 2)); + Tensor x_init_value = Tensor(DT_FLOAT, x_shape); + SetRandomValuesForMaxPooling(&x_init_value); + TensorShape y_shape({1, 2, 3, 1}); + RunTest(x, x_init_value, y.output, y_shape); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc index 4ddddcb5863c9ffb1e5367db750b0d2ffd29cd5e..23e9dc40d23899b9cef168c9128b6d8ed1be3ee9 100644 --- a/tensorflow/cc/tools/freeze_saved_model.cc +++ b/tensorflow/cc/tools/freeze_saved_model.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/cc/tools/freeze_saved_model.h" +#include #include #include "tensorflow/core/framework/attr_value.pb.h" @@ -71,6 +72,15 @@ void GetNodeNameToNodeDefMap( } } +// Strips off the tensor part of the tensor_name to get the node_name. +const string GetNodeNameFromTensorName(string tensor_name) { + if (tensor_name[0] == '^') { + tensor_name.erase(0, 1); + } + std::vector tensor_name_parts = str_util::Split(tensor_name, ':'); + return tensor_name_parts[0]; +} + // Gets the set of node names needed by `outputs` and the corresponding set of // variable nodes to convert. void GetReachableNodesAndVariables( @@ -83,10 +93,8 @@ void GetReachableNodesAndVariables( new std::unordered_set({"Variable", "VariableV2", "VarHandleOp"}); std::queue nodes_to_visit; - for (const string& tensor_name : outputs) { - // We need to strip off the tensor part to get the node name. - std::vector tensor_name_parts = str_util::Split(tensor_name, ':'); - nodes_to_visit.push(tensor_name_parts[0]); + for (const string& output_tensor_name : outputs) { + nodes_to_visit.push(GetNodeNameFromTensorName(output_tensor_name)); } // We do a traversal backwards from the outputs specified in the MetaGraphDef. 
while (!nodes_to_visit.empty()) { @@ -100,8 +108,8 @@ void GetReachableNodesAndVariables( if (kVariableTypes->find(node->op()) != kVariableTypes->end()) { variable_node_names->insert(node->name()); } - for (const string& input : node->input()) { - nodes_to_visit.push(input); + for (const string& input_tensor_name : node->input()) { + nodes_to_visit.push(GetNodeNameFromTensorName(input_tensor_name)); } } } diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc index cd35fd3b95deec669218cfa4f25fea2c3ac9e56e..979b23c3fc5f66ec574736cb4d39cec0ffd8e6b6 100644 --- a/tensorflow/cc/tools/freeze_saved_model_test.cc +++ b/tensorflow/cc/tools/freeze_saved_model_test.cc @@ -351,6 +351,56 @@ TEST_F(FreezeTest, GraphDefWithNoVariables) { GraphDefEqual(frozen_graph_def, graph_def); } +TEST_F(FreezeTest, GraphDefWithMultiOutputOperation) { + // Tensors from operations with multiple outputs get tensor suffixes when used + // in input fields of following nodes, i.e. split:0, split:1. + // Test that we traverse those correctly. 
+ SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), {10.0f, 10.0f}, {2}); + Output axis = ops::Const(scope.WithOpName("axis"), 0, {}); + OutputList split = ops::Split(scope.WithOpName("split"), axis, a, 2).output; + Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); + Output c = ops::Mul(scope.WithOpName("c"), split[1], b); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "", + &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + GraphDefEqual(frozen_graph_def, graph_def); +} + +TEST_F(FreezeTest, GraphDefWithControlDependency) { + // Inputs that are control dependencies get tensor prefixes, + // i.e. ^control_dependency. + // Test that we traverse those correctly. 
+ SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output source = ops::Const(scope.WithOpName("source"), 10.0f, {}); + Output a = ops::Const(scope.WithOpName("a").WithControlDependencies(source), + {10.0f, 10.0f}, {2}); + Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); + Output c = ops::Mul(scope.WithOpName("c"), a, b); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "", + &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + GraphDefEqual(frozen_graph_def, graph_def); +} + TEST_F(FreezeTest, GraphDefWithoutDependentVariables) { TestFreezeGraphWithoutDependentVariables(false); } diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 19e6bf68e77725bb3cae4e1d338c52dff472cb18..2119c8ec47f941a76e81346ae5d20da78eae11a3 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -214,7 +214,6 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "@llvm//:core", - "@llvm//:execution_engine", "@llvm//:support", "@llvm//:target", ], diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 6e050cf56494e6d26e3647e3261a657eeaad64fa..6641d45e83020f4144616a6a2837c844330298f5 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -56,9 +56,9 @@ namespace bar { // // Memory stats: // arg bytes total: 104 -// arg bytes aligned: 128 +// arg bytes aligned: 192 // temp bytes total: 126 -// temp bytes aligned: 224 +// temp bytes aligned: 320 class MyClass : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. 
diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc index 63d22de1ca4aa0872b6fad3e0ac0182306d7cb8c..4e27aafec7747655d8e4ea3ddd1788d495ca0710 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -82,7 +82,8 @@ static StatusOr CodegenModule(llvm::TargetMachine* target_machine, llvm::legacy::PassManager codegen_passes; if (target_machine->addPassesToEmitFile( - codegen_passes, ostream, llvm::TargetMachine::CGFT_ObjectFile)) { + codegen_passes, ostream, nullptr, + llvm::TargetMachine::CGFT_ObjectFile)) { return xla::InternalError( "Could not create pass pipeline to generate object file"); } diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h index ebfe4806c203e901358d5c5096c10c03d4c738c3..4e194a6aba9a9efcad27c47c42e148d8e537ae68 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.h +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -71,7 +71,7 @@ struct ProtobufToEmbed { const ::tensorflow::protobuf::MessageLite* message; }; -// Embeds a a sequence of protocol buffers into an object file. +// Embeds a sequence of protocol buffers into an object file. // // `target_triple` is the target triple for the target architecture for the // generated object file. diff --git a/tensorflow/compiler/aot/runtime.h b/tensorflow/compiler/aot/runtime.h index d085864f0012e4de55685bb46961417bb3070e6f..d1a669ceb17b9fd71d26e978035283f8824b0376 100644 --- a/tensorflow/compiler/aot/runtime.h +++ b/tensorflow/compiler/aot/runtime.h @@ -25,8 +25,8 @@ namespace tensorflow { namespace tfcompile { namespace runtime { -// Align to 32-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. -static constexpr size_t kAlign = 32; +// Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. 
+static constexpr size_t kAlign = 64; // aligned_buffer_bytes returns the sum of each size in `sizes`, skipping -1 // values. There are `n` entries in `sizes`. Each buffer is aligned to kAlign diff --git a/tensorflow/compiler/aot/runtime_test.cc b/tensorflow/compiler/aot/runtime_test.cc index 6d603a02eb4ceade6832ba67b2981814ee25327a..06ec623eb2dce5f8dc7156fb7e7b9ad57d90c8ee 100644 --- a/tensorflow/compiler/aot/runtime_test.cc +++ b/tensorflow/compiler/aot/runtime_test.cc @@ -24,7 +24,7 @@ namespace runtime { namespace { TEST(Runtime, AlignmentValue) { - // We've chosen 32 byte alignment for the tfcompile runtime to mimic the + // We've chosen 64 byte alignment for the tfcompile runtime to mimic the // regular tensorflow allocator, which was chosen to play nicely with Eigen. // The tfcompile runtime also has a requirement that comes from the xla // generated code, on the relation: buffer_size >= 16 ? 2 * sizeof(void*) : 8 @@ -39,13 +39,13 @@ TEST(Runtime, AlignedBufferBytes) { EXPECT_EQ(aligned_buffer_bytes(sizesA, 1), 0); static constexpr intptr_t sizesB[1] = {3}; - EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 32); + EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 64); static constexpr intptr_t sizesC[1] = {32}; - EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 32); + EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 64); static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3}; - EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 192); + EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 320); } void* add_ptr(void* base, uintptr_t delta) { @@ -101,11 +101,11 @@ TEST(Runtime, MallocFreeContiguousBuffers) { EXPECT_NE(base, nullptr); EXPECT_EQ(bufD[0], add_ptr(base, 0)); EXPECT_EQ(bufD[1], nullptr); - EXPECT_EQ(bufD[2], add_ptr(base, 32)); + EXPECT_EQ(bufD[2], add_ptr(base, 64)); EXPECT_EQ(bufD[3], nullptr); - EXPECT_EQ(bufD[4], add_ptr(base, 64)); - EXPECT_EQ(bufD[5], add_ptr(base, 128)); - EXPECT_EQ(bufD[6], add_ptr(base, 160)); + EXPECT_EQ(bufD[4], add_ptr(base, 128)); + EXPECT_EQ(bufD[5], 
add_ptr(base, 192)); + EXPECT_EQ(bufD[6], add_ptr(base, 256)); for (int i = 0; i < 7; ++i) { const intptr_t size = sizesD[i]; if (size != -1) { diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index fd2cf2b67d4618dd626b8eef78eed044d7fde0a4..0ecc3feeb6fef1dd691ab2785b3221075a79ba88 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -7,6 +7,10 @@ package( load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +# We disable some tfcompile tests in the open source build with the +# "manual" tag to avoid making our OSS users build LLVM twice +# (once for host and once for target). + test_suite( name = "all_tests", tags = ["manual"], diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 309a991fc11ab74ddd58a6345d9d40ad84fb2734..fee46280e9a0e7ba2cf7c3ed46469ae8cc0841d4 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -40,7 +40,7 @@ namespace tfcompile { namespace { using ::testing::HasSubstr; -using ::testing::UnorderedElementsAre; +using ::testing::IsSupersetOf; TEST(TFCompileTest, Add) { AddComp add; @@ -551,25 +551,20 @@ TEST(TFCompileTest, HloProfiling) { auto header = HasSubstr("Execution profile for"); auto total_cycles_profile_line = HasSubstr("[total]"); auto dot_profile_line = HasSubstr( - "%dot.0.2 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " + "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " "%arg1.0.1)"); auto add_profile_line = HasSubstr( - "%add.0.5 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " + "%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " "%arg1.0.1)"); auto tuple_profile_line = HasSubstr( "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} " - "%dot.0.2, f32[2,2]{1,0} %add.0.5)"); + "%dot.0.4, 
f32[2,2]{1,0} %add.0.6)"); auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)"); auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)"); - hlo_profile_lines.erase(hlo_profile_lines.begin() + 7, - hlo_profile_lines.end()); - - EXPECT_THAT( - hlo_profile_lines, - UnorderedElementsAre(header, total_cycles_profile_line, dot_profile_line, - add_profile_line, tuple_profile_line, - arg0_profile_line, arg1_profile_line)); + EXPECT_THAT(hlo_profile_lines, + IsSupersetOf({header, total_cycles_profile_line, dot_profile_line, + add_profile_line, tuple_profile_line})); } } // namespace diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index a6b3ce394c6859c4f45bbde4e39dde9229da3388..51a79e2cd9604b11d4411b9acab4fbea13282469 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -25,6 +25,7 @@ load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") # Target that bundles up the XLA CPU and GPU JIT devices. 
cc_library( @@ -124,7 +125,6 @@ cc_library( srcs = ["xla_tensor.cc"], hdrs = ["xla_tensor.h"], deps = [ - ":common", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:shaped_buffer", @@ -176,11 +176,13 @@ cc_library( "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:no_op", + "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sendrecv_ops", + "//tensorflow/core/kernels:shape_ops", "//tensorflow/core/kernels:variable_ops", - "@com_google_absl//absl/memory", ], ) @@ -217,6 +219,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:gpu_runtime", @@ -272,7 +275,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/memory", ], alwayslink = 1, ) @@ -293,7 +295,6 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", - "@com_google_absl//absl/memory", ], ) @@ -313,6 +314,7 @@ cc_library( ":common", ":shape_inference_helpers", ":union_find", + ":xla_cluster_util", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/kernels:parallel_check_op", "//tensorflow/compiler/jit/legacy_flags:encapsulate_subgraphs_pass_flags", @@ -333,6 +335,19 @@ cc_library( ], ) +cc_library( + name = "xla_cluster_util", + srcs = ["xla_cluster_util.cc"], + hdrs = ["xla_cluster_util.h"], + deps = [ + "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + 
"//tensorflow/core:protos_all_cc", + "//tensorflow/core/kernels:bounds_check", + ], +) + cc_library( name = "union_find", hdrs = ["union_find.h"], @@ -409,6 +424,38 @@ tf_cc_test( ], ) +cc_library( + name = "xla_fusion_optimizer", + srcs = ["xla_fusion_optimizer.cc"], + hdrs = ["xla_fusion_optimizer.h"], + visibility = ["//visibility:public"], + deps = [ + ":common", + ":union_find", + ":xla_cluster_util", + "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + ], +) + +tf_cuda_cc_test( + name = "xla_fusion_optimizer_test", + srcs = ["xla_fusion_optimizer_test.cc"], + deps = [ + ":common", + ":xla_cluster_util", + ":xla_fusion_optimizer", + "//tensorflow/core:graph", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler/utils:grappler_test", + ], +) + # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. 
cc_header_only_library( name = "xla_jit_headers_lib", diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc index 9a2bb0007527557f79b70ad2b9c9576af2ab10ea..b17ff589e2597f8d1b5e61f4eaaed7d6ebe6214c 100644 --- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc @@ -40,7 +40,7 @@ static Status BuildLaunchNode( Graph* graph, Node** node) { NodeDef def; def.set_name(graph->NewName(nodename)); - def.set_op("_XlaLaunch"); + def.set_op("XlaLaunch"); def.set_device(device_name); AddNodeAttr("Tconstants", constant_dtypes, &def); AddNodeAttr("Targs", arg_dtypes, &def); @@ -79,7 +79,7 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) { node->input_types().begin() + num_constant_args, node->input_types().begin() + num_constant_args + num_nonconst_args); - // Build a _XlaLaunch operator to execute the function body. + // Build a XlaLaunch operator to execute the function body. Node* launch_node; TF_RETURN_IF_ERROR(BuildLaunchNode( graph->NewName(node->name()), node->type_string(), node->def().attr(), diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index f35e916eb937faf7e1afd53a4a5dfdb95a8bbe43..731b8ebfdc6262500940274c94a03ae7c0376096 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/create_xla_launch_op.h" -#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/kernels/xla_launch_op.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" @@ -23,6 +22,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace { @@ -203,8 +203,8 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types, fbody->ret_types, output_memory_types, flr->graph_def_version(), &s); - *kernel = absl::make_unique( - &construction, constant_arg_indices, resource_arg_indices, function); + *kernel = MakeUnique(&construction, constant_arg_indices, + resource_arg_indices, function); return s; } diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc index bcd5e75c7e4c021a9be874ed96e994768bb80811..b75ab486b80e098bc0a59f9ea8cdbaa23a28fef9 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op_test.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/jit/create_xla_launch_op.h" -#include "absl/memory/memory.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function_testlib.h" @@ -25,6 +24,7 @@ limitations under the License. 
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -65,11 +65,11 @@ class CreateXlaLaunchOpTest : public ::testing::Test { for (const auto& fdef : flib) { *(proto.add_function()) = fdef; } - lib_def_ = absl::make_unique( - OpRegistry::Global(), proto); + lib_def_ = + MakeUnique(OpRegistry::Global(), proto); OptimizerOptions opts; - device_mgr_ = absl::make_unique(devices_); - pflr_ = absl::make_unique( + device_mgr_ = MakeUnique(devices_); + pflr_ = MakeUnique( device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 34be4409a381197d2191e083727aa8d48ab8cd63..5fee36f022a7515504cb6faa5cca658481b784c5 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -80,7 +80,7 @@ Status EncapsulateSubgraphsInFunctions( std::unique_ptr* graph_out, FunctionLibraryDefinition* library); // The attribute that marks function calls produced by the encapsulate -// subgraphs pass and that should in turn be compiled via _XlaLaunch operators. +// subgraphs pass and that should in turn be compiled via XlaLaunch operators. extern const char* const kXlaCompiledKernelAttr; // Does `node` have the kXlaCompiledKernelAttr attribute? 
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 5ec24d39a2c40a766dbb0ec51ebe798de620e24b..eef113a3547f0b2f648680d5f51650f70dbbd261 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -1050,7 +1050,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { .WithAttr("_outside", "O1")); Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT}, shape2.opts()); - Node* h = Binary(ops::NodeOut(recv2, 0), e, + Node* h = Binary(ops::NodeOut(recv2, 1), e, shape2.opts() .WithName("H") .WithAttr("_encapsulate", "F1") @@ -1075,7 +1075,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", - {"D:o:0", "F:o:0"}, + {"F:o:0", "D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"ancestors", @@ -1123,13 +1123,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT}, b2.opts()); - Node* g = Binary(e, ops::NodeOut(recv2, 1), + Node* g = Binary(e, ops::NodeOut(recv2, 0), b2.opts() .WithName("G") .WithControlInputs({recv2, e}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2")); - Node* h = Binary(ops::NodeOut(recv2, 0), e, + Node* h = Binary(ops::NodeOut(recv2, 1), e, b2.opts() .WithName("H") .WithAttr("_encapsulate", "F1") diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/jit/graphcycles/graphcycles.cc index bc68afb322b5cfc814ce0537254ba14053ae4550..805bbc62c1e2e877de87ab8faf3d60b829743df8 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc @@ -354,6 +354,16 @@ bool GraphCycles::IsReachableNonConst(int32 x, 
int32 y) { return reachable; } +bool GraphCycles::CanContractEdge(int32 a, int32 b) { + CHECK(HasEdge(a, b)) << "No edge exists from " << a << " to " << b; + RemoveEdge(a, b); + bool reachable = IsReachableNonConst(a, b); + // Restore the graph to its original state. + InsertEdge(a, b); + // If reachable, then contracting edge will cause cycle. + return !reachable; +} + bool GraphCycles::ContractEdge(int32 a, int32 b) { CHECK(HasEdge(a, b)); RemoveEdge(a, b); @@ -388,4 +398,8 @@ std::unordered_set GraphCycles::Successors(int32 node) { return rep_->nodes_[node]->out; } +std::unordered_set GraphCycles::Predecessors(int32 node) { + return rep_->nodes_[node]->in; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.h b/tensorflow/compiler/jit/graphcycles/graphcycles.h index d11d6e27b1b7bb514127e16a9be21f044100d885..44448fa3d787d0785a797d40ed1b968438a903c9 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.h +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.h @@ -85,6 +85,9 @@ class GraphCycles { // and returns false. bool ContractEdge(int32 a, int32 b); + // Return true if can contract edge, otherwise return false. + bool CanContractEdge(int32 a, int32 b); + // Return whether dest_node is reachable from source_node // by following edges. 
bool IsReachable(int32 source_node, int32 dest_node) const; @@ -115,6 +118,7 @@ class GraphCycles { bool CheckInvariants() const; std::unordered_set Successors(int32 node); + std::unordered_set Predecessors(int32 node); // ---------------------------------------------------- struct Rep; diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc b/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc index e47b782207e9122740fe9d5daf1fa0dbaeb47754..274f5938a1228baf68ad4d8e1a7b13f276321d27 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc @@ -494,6 +494,20 @@ TEST_F(GraphCyclesTest, ContractEdge) { EXPECT_TRUE(g_.HasEdge(1, 4)); } +TEST_F(GraphCyclesTest, CanContractEdge) { + ASSERT_TRUE(AddEdge(1, 2)); + ASSERT_TRUE(AddEdge(1, 3)); + ASSERT_TRUE(AddEdge(2, 3)); + ASSERT_TRUE(AddEdge(2, 4)); + ASSERT_TRUE(AddEdge(3, 4)); + + EXPECT_FALSE(g_.CanContractEdge(1, 3)); + EXPECT_FALSE(g_.CanContractEdge(2, 4)); + EXPECT_TRUE(g_.CanContractEdge(1, 2)); + EXPECT_TRUE(g_.CanContractEdge(2, 3)); + EXPECT_TRUE(g_.CanContractEdge(3, 4)); +} + static void BM_StressTest(int iters, int num_nodes) { while (iters > 0) { tensorflow::GraphCycles g; diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 86a9fd3b8e124e581bc4b73f264dbd5be46c790a..902fe27acdec1cb323217e6409fbd02f62177612 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -112,7 +112,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { // this is more obviously correct.) 
core::ScopedUnref cache_ref(cache); - const XlaDevice::Metadata* metadata; + const XlaDevice::Metadata* metadata = nullptr; Status s = XlaDevice::GetMetadata(ctx, &metadata); bool allocate_xla_tensors = s.ok(); @@ -148,14 +148,14 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { XlaCompiler::Options options; options.client = client; - options.device_type = &cache->device_type(); + options.device_type = cache->device_type(); options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.graph_def_version = ctx->function_library()->graph_def_version(); options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId); options.device_allocator = xla_allocator; - // TODO(b/77671268): We don't set variable_representation_shape_fn here. This - // is restricted to Variables, but we need something like this to apply to - // normal Tensors too. + if (metadata) { + options.shape_representation_fn = metadata->shape_representation_fn(); + } const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; @@ -164,9 +164,11 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { for (int i : constants_) { constant_args.insert({i, ctx->input(i)}); } - OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args, - variables, ctx, &kernel, &executable, - /*compile_options=*/nullptr)); + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; + OP_REQUIRES_OK( + ctx, cache->Compile(options, function_, constant_args, variables, ctx, + &kernel, &executable, &compile_options)); VLOG(1) << "Executing XLA Computation..."; @@ -254,10 +256,9 @@ XlaLocalLaunchOp::~XlaLocalLaunchOp() { VLOG(1) << "XlaLocalLaunchOp destroyed"; } -REGISTER_KERNEL_BUILDER(Name("_XlaLaunch").Device(DEVICE_CPU), - XlaLocalLaunchOp); +REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp); -REGISTER_KERNEL_BUILDER(Name("_XlaLaunch") +REGISTER_KERNEL_BUILDER(Name("XlaLaunch") 
.Device(DEVICE_GPU) .HostMemory("constants") .HostMemory("resources"), diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 8e2ee0f1d71bc17b4c12c792c38002af4f9eb5eb..8c3882116dd4f048ea3e32c037bf4139c67a3eb9 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/function.h" @@ -41,9 +42,6 @@ limitations under the License. namespace tensorflow { -const char* const kXlaClusterAttr = "_XlaCluster"; -const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation"; - namespace { bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { @@ -60,6 +58,14 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { return false; } } + + // XLA does not offer guaranteed aliasing between the input and output of the + // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave + // such nodes out of XLA clusters. + if (HasForwardedRefInput(node)) { + return false; + } + return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok(); } @@ -165,16 +171,6 @@ bool IsCompilableCall(const NodeDef& call_def, return true; } -// Returns the DeviceType corresponding to 'device'. 
-Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) { - DeviceNameUtils::ParsedName parsed; - if (!DeviceNameUtils::ParseFullName(device, &parsed)) { - return errors::Internal("Malformed assigned device '", device, "'"); - } - *device_type = DeviceType(parsed.type); - return Status::OK(); -} - // Tests whether `node` has a DT_RESOURCE typed input or output. bool HasResourceInputOrOutput(const Node& node) { return std::find(node.input_types().begin(), node.input_types().end(), @@ -183,18 +179,11 @@ bool HasResourceInputOrOutput(const Node& node) { DT_RESOURCE) != node.output_types().end(); } -struct NodeCompare { - bool operator()(const Node* a, const Node* b) const { - return a->id() < b->id(); - } -}; -using OrderedNodeSet = std::set; - // Returns true if the op can be decomposed into XLA ops for which // there are fusable elemental implementations. // -// TODO(hpucha): Consider a black list instead of a white list as -// implemented below. +// TODO(hpucha): Remove this code since this functionality is subsumed by +// Grappler XlaFusionOptimizer. 
bool IsXlaFusable(const NodeDef& node) { static const std::unordered_set* elementwise_ops = new std::unordered_set( @@ -364,7 +353,7 @@ Status FindCompilationCandidates( for (Node* node : graph.op_nodes()) { sorted_nodes.push_back(node); } - std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeCompare()); + std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID()); for (Node* node : sorted_nodes) { VLOG(2) << "Fuel: " << fuel; @@ -379,9 +368,13 @@ Status FindCompilationCandidates( DeviceType device_type(""); TF_RETURN_IF_ERROR( - DeviceTypeOfDevice(node->assigned_device_name(), &device_type)); + DeviceToDeviceType(node->assigned_device_name(), &device_type)); - if (is_compilable_fn && !is_compilable_fn(node, device_type)) continue; + if (is_compilable_fn && !is_compilable_fn(node, device_type)) { + VLOG(2) << "Compilation rejected node: not compilable " << node->name() + << ": " << node->type_string(); + continue; + } const XlaOpRegistry::DeviceRegistration* registration; CHECK( @@ -430,46 +423,6 @@ struct Cluster { int representative = -1; }; -// Returns a string describing how an edge from src to dst would -// create a cycle. 
-string DescribeCycle(const GraphCycles& cycles, const Graph& graph, int src, - int dst) { - int32 max_path_size = graph.num_node_ids() + 1; - std::vector path(max_path_size); - int32 path_size = cycles.FindPath(dst, src, max_path_size, path.data()); - if (path_size == 0) { - return ""; - } - - auto node_name = [&cycles, &graph](int node_id) { - if (!FastBoundsCheck(node_id, graph.num_node_ids())) { - return string("(null)"); - } - auto* node = graph.FindNodeId(node_id); - if (node == nullptr) { - return string("(null)"); - } - return node->name(); - }; - - string description; - strings::StrAppend(&description, "Edge from ", node_name(src), " to ", - node_name(dst), " would create a cycle.\n"); - path.resize(path_size); - for (int32 node_id : path) { - string ascii_art; - if (node_id == dst) { - ascii_art = "+-> "; - } else if (node_id != src) { - ascii_art = "| "; - } else { - ascii_art = "+-- "; - } - strings::StrAppend(&description, ascii_art, node_name(node_id), "\n"); - } - return description; -} - } // anonymous namespace bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { @@ -575,84 +528,13 @@ Status MarkForCompilationPass::RunImpl( : Env::Default(), is_compilable_fn, &compilation_candidates)); - GraphCycles cycles; - for (int i = 0; i < graph->num_node_ids(); ++i) { - // We rely on the node IDs in the cycle detection graph being consecutive - // integers starting from 0. - CHECK_EQ(i, cycles.NewNode()); + if (compilation_candidates.empty()) { + VLOG(2) << "No compilable candidates"; + return Status::OK(); } - // Compute the loop structure of the graph. - std::vector control_flow_info; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info)); - - // The clustering code must avoid adding cycles to the graph to prevent - // deadlock. However, the graph may contain loops, which would trigger the - // cycle detection code. 
To handle loops, we alter the structure of the cycle - // detection graph, disconnecting each loop from the enclosing graph. - // Specifically, we: - // * add a new "frame" node for each loop. - // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges - // to/from the corresponding frame node. In essence, we collapse the loop - // into a single node for the purpose of cycle detection in the enclosing - // graph. - // * the body of the loop should now be disconnected from the rest of the - // graph; we make it acyclic by breaking loop backedges (edges outgoing from - // "NextIteration" nodes. - - // Map from frame name strings to node IDs in the cycle detection graph. - std::unordered_map frame_nodes; - - // Get the cycle graph node ID for frame 'frame_name', or add one if none - // exists. - auto GetOrAddFrameNodeId = [&frame_nodes, &cycles](const string& frame_name) { - int& frame_id = frame_nodes.emplace(frame_name, -1).first->second; - if (frame_id < 0) { - // The emplace succeeded; we have not allocated a frame node yet. - frame_id = cycles.NewNode(); - } - return frame_id; - }; - - for (Edge const* edge : graph->edges()) { - if (edge->dst()->IsEnter()) { - // Lift edges to an "Enter" node to the corresponding frame node. - const string& frame_name = - control_flow_info[edge->dst()->id()].frame_name; - int dst = GetOrAddFrameNodeId(frame_name); - if (!cycles.InsertEdge(edge->src()->id(), dst)) { - return errors::Internal( - "Cycle detected when adding enter->frame edge: ", - DescribeCycle(cycles, *graph, edge->src()->id(), dst)); - } - continue; - } - if (edge->src()->IsExit()) { - // Lift edges from an "Exit" node to the corresponding frame node. 
- const string& frame_name = - control_flow_info[edge->src()->id()].frame_name; - int src = GetOrAddFrameNodeId(frame_name); - if (!cycles.InsertEdge(src, edge->dst()->id())) { - return errors::Internal( - "Cycle detected when adding frame->exit edge: ", - DescribeCycle(cycles, *graph, src, edge->dst()->id())); - } - // Drop the original edge. - continue; - } - if (edge->src()->IsNextIteration()) { - // Break loop back-edges. - continue; - } - if (!cycles.InsertEdge(edge->src()->id(), edge->dst()->id())) { - // This should never happen. All cycles in the graph should contain - // a control flow operator. - return errors::Internal( - "Found cycle in graph without control flow operator during XLA " - "compilation: ", - DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id())); - } - } + GraphCycles cycles; + TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles)); // Each compilation candidate belongs to a cluster. The cluster's // representative @@ -670,6 +552,9 @@ Status MarkForCompilationPass::RunImpl( // Repeatedly contract edges between clusters that are on the same device, // provided the contraction would not create a cycle. + // + // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for + // example, from the Grappler fusion pass). while (!worklist.empty()) { int from = worklist.front()->Get().representative; worklist.pop_front(); @@ -778,7 +663,7 @@ Status MarkForCompilationPass::RunImpl( // compilation. 
DeviceType device_type(""); TF_RETURN_IF_ERROR( - DeviceTypeOfDevice(n->assigned_device_name(), &device_type)); + DeviceToDeviceType(n->assigned_device_name(), &device_type)); const XlaOpRegistry::DeviceRegistration* registration; XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 703d8825d74ced8d4d69c31ccd730adc89a8bffe..772c92d369e67f431b5d030d1d5cdc5ae2700d39 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -633,5 +633,52 @@ TEST(XlaCompilationTest, ConstOp) { } } +TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output variable = ops::Variable(root.WithOpName("variable"), + PartialTensorShape{}, DT_FLOAT); + Output read = ops::Identity(root.WithOpName("read"), variable); + Output neg = ops::Negate(root.WithOpName("negate"), read); + Output add = ops::Add(root.WithOpName("add"), neg, neg); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + ASSERT_FALSE(clusters.empty()); + string cluster_name = clusters.begin()->second; + + std::unordered_map expected_clusters( + {{"negate", cluster_name}, {"add", cluster_name}}); + EXPECT_EQ(clusters, expected_clusters); +} + +TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output variable = ops::Variable(root.WithOpName("variable"), + PartialTensorShape{}, DT_FLOAT); + Output read = ops::Identity(root.WithOpName("read"), variable); + Output neg = ops::Negate(root.WithOpName("negate"), read); + Output identity = ops::Negate(root.WithOpName("identity"), neg); + Output add = ops::Add(root.WithOpName("add"), 
identity, neg); + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + + std::unordered_map<string, string> clusters = GetClusters(*graph); + + ASSERT_FALSE(clusters.empty()); + string cluster_name = clusters.begin()->second; + + std::unordered_map<string, string> expected_clusters( + {{"negate", cluster_name}, + {"identity", cluster_name}, + {"add", cluster_name}}); + EXPECT_EQ(clusters, expected_clusters); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index 07320b43dab790e6cda5e85688bdacf48a35adc4..f2473d98ffd5dae55983e601b8d2d65af6a6d54c 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -17,7 +17,7 @@ limitations under the License. namespace tensorflow { -REGISTER_OP("_XlaLaunch") +REGISTER_OP("XlaLaunch") .Input("constants: Tconstants") .Attr("Tconstants: list(type) >= 0") .Input("args: Targs") @@ -28,7 +28,7 @@ REGISTER_OP("_XlaLaunch") .Attr("Tresults: list(type) >= 0") .Attr("function: func") // XLA random-number generation ops are stateful. - // TODO(phawkins): create stateful and non-stateful variants of _XlaLaunch. + // TODO(phawkins): create stateful and non-stateful variants of XlaLaunch. .SetIsStateful() .Doc("XLA Launch Op. For use by the XLA JIT only."); diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..05b7821b8865d0f210ca9af92370e177d6043e80 --- /dev/null +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -0,0 +1,183 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_cluster_util.h" + +#include + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +const char* const kXlaClusterAttr = "_XlaCluster"; +const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation"; + +namespace { +// Returns a string describing how an edge from src to dst would +// create a cycle. 
+string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, + int dst) { + int32 max_path_size = graph.num_node_ids() + 1; + std::vector path(max_path_size); + int32 path_size = cycles->FindPath(dst, src, max_path_size, path.data()); + if (path_size == 0) { + return ""; + } + + auto node_name = [cycles, &graph](int node_id) { + if (!FastBoundsCheck(node_id, graph.num_node_ids())) { + return string("(null)"); + } + auto* node = graph.FindNodeId(node_id); + if (node == nullptr) { + return string("(null)"); + } + return node->name(); + }; + + string description; + strings::StrAppend(&description, "Edge from ", node_name(src), " to ", + node_name(dst), " would create a cycle.\n"); + path.resize(path_size); + for (int32 node_id : path) { + string ascii_art; + if (node_id == dst) { + ascii_art = "+-> "; + } else if (node_id != src) { + ascii_art = "| "; + } else { + ascii_art = "+-- "; + } + strings::StrAppend(&description, ascii_art, node_name(node_id), "\n"); + } + return description; +} + +bool AlwaysForwardsRefInput(const Node& node) { return node.IsIdentity(); } + +} // namespace + +Status DeviceToDeviceType(const string& device, DeviceType* device_type) { + DeviceNameUtils::ParsedName parsed; + if (!DeviceNameUtils::ParseFullName(device, &parsed)) { + return errors::Internal("Malformed assigned device '", device, "'"); + } + *device_type = DeviceType(parsed.type); + return Status::OK(); +} + +bool HasForwardedRefInput(const Node& node) { + if (AlwaysForwardsRefInput(node)) { + for (const Edge* incoming_edge : node.in_edges()) { + if (incoming_edge->IsControlEdge()) { + continue; + } + + Node* incoming_node = incoming_edge->src(); + if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) { + VLOG(2) << "Node " << node.def().ShortDebugString() << " has ref input " + << incoming_node->name() << " " << incoming_node->type_string(); + return true; + } + } + } + return false; +} + +Status CreateCycleDetectionGraph(const Graph* graph, 
GraphCycles* cycles) { + for (int i = 0; i < graph->num_node_ids(); ++i) { + // We rely on the node IDs in the cycle detection graph being consecutive + // integers starting from 0. + CHECK_EQ(i, cycles->NewNode()); + } + + // Compute the loop structure of the graph. + std::vector control_flow_info; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info)); + + // The clustering code must avoid adding cycles to the graph to prevent + // deadlock. However, the graph may contain loops, which would trigger the + // cycle detection code. To handle loops, we alter the structure of the cycle + // detection graph, disconnecting each loop from the enclosing graph. + // Specifically, we: + // * add a new "frame" node for each loop. + // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges + // to/from the corresponding frame node. In essence, we collapse the loop + // into a single node for the purpose of cycle detection in the enclosing + // graph. + // * the body of the loop should now be disconnected from the rest of the + // graph; we make it acyclic by breaking loop backedges (edges outgoing from + // "NextIteration" nodes. + + // Map from frame name strings to node IDs in the cycle detection graph. + std::unordered_map frame_nodes; + + // Get the cycle graph node ID for frame 'frame_name', or add one if none + // exists. + auto GetOrAddFrameNodeId = [&frame_nodes, cycles](const string& frame_name) { + int& frame_id = frame_nodes.emplace(frame_name, -1).first->second; + if (frame_id < 0) { + // The emplace succeeded; we have not allocated a frame node yet. + frame_id = cycles->NewNode(); + } + return frame_id; + }; + + for (Edge const* edge : graph->edges()) { + if (edge->dst()->IsEnter()) { + // Lift edges to an "Enter" node to the corresponding frame node. 
+ const string& frame_name = + control_flow_info[edge->dst()->id()].frame_name; + int dst = GetOrAddFrameNodeId(frame_name); + if (!cycles->InsertEdge(edge->src()->id(), dst)) { + return errors::Internal( + "Cycle detected when adding enter->frame edge: ", + DescribeCycle(cycles, *graph, edge->src()->id(), dst)); + } + continue; + } + if (edge->src()->IsExit()) { + // Lift edges from an "Exit" node to the corresponding frame node. + const string& frame_name = + control_flow_info[edge->src()->id()].frame_name; + int src = GetOrAddFrameNodeId(frame_name); + if (!cycles->InsertEdge(src, edge->dst()->id())) { + return errors::Internal( + "Cycle detected when adding frame->exit edge: ", + DescribeCycle(cycles, *graph, src, edge->dst()->id())); + } + // Drop the original edge. + continue; + } + if (edge->src()->IsNextIteration()) { + // Break loop back-edges. + continue; + } + if (!cycles->InsertEdge(edge->src()->id(), edge->dst()->id())) { + // This should never happen. All cycles in the graph should contain + // a control flow operator. + return errors::Internal( + "Found cycle in graph without control flow operator during XLA " + "compilation: ", + DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id())); + } + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h new file mode 100644 index 0000000000000000000000000000000000000000..bcce082aaf6044ff0654efa4d78c0f493a350d00 --- /dev/null +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Contains utilities for clustering compilable graph nodes via XLA. + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ + +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/core/graph/algorithm.h" + +namespace tensorflow { + +// The attribute that marks nodes to be grouped into functions by the +// encapsulate subgraphs pass. +extern const char* const kXlaClusterAttr; + +// The attribute that marks nodes in a cluster to be placed outside the xla +// compilation by the encapsulate subgraphs pass. +extern const char* const kXlaOutsideCompilationAttr; + +using OrderedNodeSet = std::set; + +// Returns the DeviceType corresponding to 'device'. +Status DeviceToDeviceType(const string& device, DeviceType* device_type); + +// Returns true if `node` has a ref tensor input that it forwards to its output. +bool HasForwardedRefInput(const Node& node); + +// Creates a graph representation to enable cycle detection when clustering. +// This representation handles loops in graph by disconnecting each loop from +// the enclosing graph. 
+Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 6430975335f5eef5b53c80213e6090ffd6166a91..7ed609c43748062656b631243c01d790519c54fd 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -122,8 +122,7 @@ Status XlaCompilationCache::BuildSignature( namespace { -// Builds a XlaCompiler::Argument vector from the arguments to the _XlaLaunch -// op. +// Builds a XlaCompiler::Argument vector from the arguments to the XlaLaunch op. Status BuildArguments(const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 6b83cf67ffc571f235ae84d0de58254c5d7e4962..b1943d3e1a7e321b5a3796a0c6e4f2b5d9ac7018 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -151,16 +151,18 @@ Status XlaCompileOnDemandOp::Compile( core::ScopedUnref cache_ref(cache); XlaCompiler::Options options; - DeviceType device_type = metadata.jit_device_type(); - options.device_type = &device_type; + options.device_type = metadata.jit_device_type(); options.client = metadata.client(); options.flib_def = new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{}); + options.shape_representation_fn = metadata.shape_representation_fn(); + + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; std::map variable_args = GetVariables(ctx); return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, - result, executable, - /*compile_options=*/nullptr); + result, executable, &compile_options); } void 
XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h index 23c6f3903f841a6c39104983c6f7f409757a7319..7cc3d0e007ba2974fbfbe6fbabc4aa08f9fa910f 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.h +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h @@ -29,11 +29,8 @@ limitations under the License. namespace tensorflow { // An OpKernel that compiles an op to an XLA computation and runs it. Unlike -// _XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a +// XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a // vanilla TensorFlow op as long as the bridge supports it. -// -// Importantly _XlaLaunch assumes all input and output tensors are on the host, -// whereas XlacompileOnDemandOp works with tensors in device memory. class XlaCompileOnDemandOp : public OpKernel { public: explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index bc07dbd7bdf005fde781f7a1e6775080e363abfb..43648402f65c656b6b4eb2e83e61ce45f1c73669 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -53,7 +53,9 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options, TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options, name_prefix, registration, - /*transfer_as_literal=*/false, &device)); + /*transfer_as_literal=*/false, + /*shape_representation_fn=*/{}, + /*padded_shape_fn=*/{}, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 70263b1ff936757101a3c47d192b2ba58271dc79..ed007d603ea1b3d27dd25f00726261cdd029c20c 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ 
b/tensorflow/compiler/jit/xla_device.cc @@ -18,12 +18,12 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/common_runtime/device.h" @@ -49,6 +49,7 @@ limitations under the License. #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/ptr_util.h" #include "tensorflow/core/util/stream_executor_util.h" namespace tensorflow { @@ -105,12 +106,33 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( return alloc_ptr; } +namespace { + +// Default PaddedShapeFn implementation that simply returns the unpadded +// on-device shape. This is accurate for CPU and GPU devices that neither +// transpose nor pad tensors. 
+Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { + const tensorflow::XlaTensor* xla_tensor = + tensorflow::XlaTensor::FromTensor(&tensor); + if (xla_tensor == nullptr) { + return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape); + } + + const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer(); + *shape = shaped_buffer.on_device_shape(); + return Status::OK(); +} + +} // namespace + /* static */ Status XlaDevice::Create( const string& platform_name, const string& device_name, int device_ordinal, const string& jit_device_name, const SessionOptions& options, const string& name_prefix, const XlaOpRegistry::DeviceRegistration& registration, - bool transfer_as_literal, std::unique_ptr* device) { + bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + const PaddedShapeFn& padded_shape_fn, std::unique_ptr* device) { VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" << device_ordinal; @@ -129,17 +151,22 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(), strings::StrCat("device: ", device_name, " device")); - device->reset(new XlaDevice(options, attrs, device_ordinal, - DeviceType(jit_device_name), - platform.ValueOrDie(), transfer_as_literal)); + device->reset(new XlaDevice( + options, attrs, device_ordinal, DeviceType(jit_device_name), + platform.ValueOrDie(), transfer_as_literal, shape_representation_fn, + padded_shape_fn ? 
padded_shape_fn : DefaultPaddedShapeFn)); return Status::OK(); } -XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform, - const DeviceType& device_type) +XlaDevice::Metadata::Metadata( + int device_ordinal, se::Platform* platform, const DeviceType& device_type, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + PaddedShapeFn padded_shape_fn) : device_ordinal_(device_ordinal), device_type_(device_type), - platform_(platform) {} + platform_(platform), + shape_representation_fn_(std::move(shape_representation_fn)), + padded_shape_fn_(std::move(padded_shape_fn)) {} int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; } @@ -170,17 +197,21 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { return Status::OK(); } -XlaDevice::XlaDevice(const SessionOptions& options, - const DeviceAttributes& attrs, int device_ordinal, - const DeviceType& jit_device_name, se::Platform* platform, - bool transfer_as_literal) +XlaDevice::XlaDevice( + const SessionOptions& options, const DeviceAttributes& attrs, + int device_ordinal, const DeviceType& jit_device_name, + se::Platform* platform, bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + const PaddedShapeFn& padded_shape_fn) : LocalDevice(options, attrs), - xla_metadata_(device_ordinal, platform, jit_device_name), + xla_metadata_(device_ordinal, platform, jit_device_name, + shape_representation_fn, padded_shape_fn), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), xla_allocator_(nullptr), platform_(platform), - transfer_as_literal_(transfer_as_literal) { + transfer_as_literal_(transfer_as_literal), + shape_representation_fn_(shape_representation_fn) { VLOG(1) << "Created XLA device " << jit_device_name; } @@ -230,10 +261,10 @@ Status XlaDevice::CreateAndSetGpuDeviceInfo() { GetAllocator({}); // XlaDevice owns both gpu_device_info_ and // gpu_device_info_->default_context. 
- gpu_device_info_ = absl::make_unique(); + gpu_device_info_ = MakeUnique(); gpu_device_info_->stream = stream; - gpu_device_info_->default_context = - new XlaDeviceContext(stream, client(), transfer_as_literal_); + gpu_device_info_->default_context = new XlaDeviceContext( + stream, client(), transfer_as_literal_, shape_representation_fn_); set_tensorflow_gpu_device_info(gpu_device_info_.get()); } @@ -247,7 +278,8 @@ Status XlaDevice::FillContextMap(const Graph* graph, TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); // Call GetAllocator for the side-effect of ensuring the allocator is created. GetAllocator({}); - auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_); + auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_, + shape_representation_fn_); for (Node* n : graph->nodes()) { VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name(); ctx->Ref(); @@ -294,7 +326,8 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); Notification n; TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - XlaTransferManager manager(stream, client(), transfer_as_literal_); + XlaTransferManager manager(stream, client(), transfer_as_literal_, + shape_representation_fn_); manager.CopyCPUTensorToDevice(&parsed, this, &copy, [&n, &status](const Status& s) { status = s; diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 3ae87308cc7cffa916e178893df70a3f314b11b0..02e88ee6793e984a7b782790f8011cbcbc5a5026 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -17,8 +17,7 @@ limitations under the License. // runtime. // // Operators assigned to an XlaDevice are compiled into XLA computations. -// Tensors on an XlaDevice are thin wrappers around XLA GlobalDataHandles; state -// is managed by XLA.
+// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers. // // XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU), // under different names (e.g., XLA_CPU or XLA_GPU). @@ -27,6 +26,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/common_runtime/device_factory.h" @@ -45,12 +45,19 @@ namespace tensorflow { class XlaDevice : public LocalDevice { public: + // Given a tensor, sets `xla::Shape*` the shape of tensor's representation + // on device, fully padded. On error, the contents of `xla::Shape*` + // are undefined. + typedef std::function PaddedShapeFn; + // Wrapper class to store metadata about the XlaDevice, where it can be // retrieved e.g., when lazily creating the XlaCompilationCache device. class Metadata { public: Metadata(int device_ordinal, se::Platform* platform, - const DeviceType& device_type); + const DeviceType& device_type, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + PaddedShapeFn padded_shape_fn); // The index of the device on this host. int device_ordinal() const; @@ -58,11 +65,17 @@ class XlaDevice : public LocalDevice { se::Platform* platform() const; xla::LocalClient* client() const; const DeviceType& jit_device_type() const; + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const { + return shape_representation_fn_; + } + const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; } private: const int device_ordinal_; const DeviceType device_type_; se::Platform* platform_; // Not owned. 
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + PaddedShapeFn padded_shape_fn_; TF_DISALLOW_COPY_AND_ASSIGN(Metadata); }; @@ -76,16 +89,25 @@ class XlaDevice : public LocalDevice { // 'transfer_as_literal' is true if device<->host transfers must be done using // XLA's TransferLiteral{To,From}Device interface. If false, we can use // ThenMemcpy instead. - static Status Create(const string& platform_name, const string& device_name, - int device_ordinal, const string& jit_device_name, - const SessionOptions& options, const string& name_prefix, - const XlaOpRegistry::DeviceRegistration& registration, - bool transfer_as_literal, - std::unique_ptr* device); - + // If padded_shape_fn is empty, a default implementation that returns + // the on-host shape is used. + static Status Create( + const string& platform_name, const string& device_name, + int device_ordinal, const string& jit_device_name, + const SessionOptions& options, const string& name_prefix, + const XlaOpRegistry::DeviceRegistration& registration, + bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + const PaddedShapeFn& padded_shape_fn, std::unique_ptr* device); + + // Creates a new XLA Device. + // If padded_shape_fn is empty, a default implementation that returns + // the logical on-device shape without padding is used. 
XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, - se::Platform* platform, bool transfer_as_literal); + se::Platform* platform, bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + const PaddedShapeFn& padded_shape_fn); ~XlaDevice() override; Allocator* GetAllocator(AllocatorAttributes attr) override; @@ -102,6 +124,7 @@ class XlaDevice : public LocalDevice { Tensor* tensor) override; xla::LocalClient* client() const; + const Metadata& metadata() { return xla_metadata_; } xla::StatusOr GetStream(); // If not already set, create and set GpuDeviceInfo. @@ -116,8 +139,8 @@ class XlaDevice : public LocalDevice { // The name of the device that is used to compile Ops for this XlaDevice. DeviceType jit_device_name_; // Memory allocator associated with this device. - Allocator* xla_allocator_; // Not owned. - se::Platform* platform_; // Not owned. + Allocator* xla_allocator_; // Not owned. + se::Platform* platform_; // Not owned. // Stream associated with this device. Operations enqueued on this // stream are executed on the device. Operations include data // copying back and forth between CPU and the device, and @@ -126,6 +149,7 @@ class XlaDevice : public LocalDevice { // Must we use XLA's transfer manager for correct host<->device transfers? if // false, we can use ThenMemcpy() instead. bool transfer_as_literal_; + XlaCompiler::ShapeRepresentationFn shape_representation_fn_; // If set, holds default device context (that we must Unref) // and its stream. 
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index bf8c1886a022310eeaacdf69463f575a393dd8d0..71e63b110b3b132a57fc291e53a165954c72a03c 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -47,22 +47,33 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) { void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } -XlaTransferManager::XlaTransferManager(se::Stream* stream, - xla::LocalClient* client, - bool transfer_as_literal) +XlaTransferManager::XlaTransferManager( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn) : stream_(stream), client_(client), transfer_manager_(client->backend().transfer_manager()), - transfer_as_literal_(transfer_as_literal) {} + transfer_as_literal_(transfer_as_literal), + shape_representation_fn_(std::move(shape_representation_fn)) { + if (!shape_representation_fn_) { + shape_representation_fn_ = [](const TensorShape& shape, DataType dtype) { + return shape; + }; + } +} Status XlaTransferManager::TransferLiteralToDevice( const Tensor& host_tensor, Tensor* device_tensor) const { - xla::Literal literal; - TF_RETURN_IF_ERROR(HostTensorToLiteral(host_tensor, &literal)); - VLOG(1) << "Transfer to device as literal: " << literal.ToString(); + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), + host_tensor.shape(), &xla_shape)); + xla::BorrowingLiteral literal( + static_cast(DMAHelper::base(&host_tensor)), xla_shape); const xla::ShapedBuffer& shaped_buffer = XlaTensor::FromTensor(device_tensor)->shaped_buffer(); + VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " " + << shaped_buffer.ToString(); return transfer_manager_->TransferLiteralToDevice(stream_->parent(), literal, shaped_buffer); } @@ -75,8 +86,17 @@ Status XlaTransferManager::TransferLiteralFromDevice( 
TF_ASSIGN_OR_RETURN(std::unique_ptr literal, transfer_manager_->TransferLiteralFromDevice( stream_->parent(), shaped_buffer)); - VLOG(1) << "Transfer from device as literal: " << literal->ToString(); - return LiteralToHostTensor(*literal, host_tensor->dtype(), host_tensor); + VLOG(1) << "Transfer from device as literal: " << literal->ToString() << " " + << shaped_buffer.ToString(); + Tensor tensor; + TF_RETURN_IF_ERROR( + LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor)); + // Reshape the tensor back to its declared shape. + if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) { + return errors::Internal( + "Tensor::CopyFrom failed when copying from XLA device to CPU"); + } + return Status::OK(); } void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, @@ -89,16 +109,21 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, << " " << reinterpret_cast( device_tensor->tensor_data().data()) - << " " << cpu_tensor->NumElements(); + << " " << cpu_tensor->NumElements() << " " + << cpu_tensor->shape().DebugString() << " " + << device_tensor->shape().DebugString(); void* src_ptr = const_cast(DMAHelper::base(cpu_tensor)); const int64 total_bytes = cpu_tensor->TotalBytes(); XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); CHECK(xla_tensor); + + TensorShape shape = shape_representation_fn_(device_tensor->shape(), + device_tensor->dtype()); if (!xla_tensor->has_shaped_buffer()) { Status s = xla_tensor->AllocateShapedBuffer( - device_tensor->dtype(), device_tensor->shape(), client_, + device_tensor->dtype(), shape, client_, stream_->parent()->device_ordinal()); if (!s.ok()) { done(s); @@ -106,12 +131,18 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } } - se::DeviceMemoryBase dev_dst_ptr = - XlaTensor::DeviceMemoryFromTensor(*device_tensor); Status status; if (transfer_as_literal_) { - status = TransferLiteralToDevice(*cpu_tensor, device_tensor); + Tensor reshaped_cpu_tensor; + if 
(!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) { + done(errors::Internal( + "Tensor::CopyFrom failed when copying from CPU to XLA device")); + return; + } + status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); } else { + se::DeviceMemoryBase dev_dst_ptr = + XlaTensor::DeviceMemoryFromTensor(*device_tensor); stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); // TODO(hpucha): Make this asynchronous. Status block_status = stream_->BlockHostUntilDone(); @@ -142,7 +173,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, device_tensor->tensor_data().data()) << " " << reinterpret_cast(cpu_tensor->tensor_data().data()) - << device_tensor->NumElements(); + << " " << device_tensor->NumElements() << " " + << cpu_tensor->shape().DebugString() << " " + << device_tensor->shape().DebugString(); const int64 total_bytes = cpu_tensor->TotalBytes(); se::DeviceMemoryBase dev_src_ptr = @@ -171,9 +204,47 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, done(Status::OK()); } -XlaDeviceContext::XlaDeviceContext(se::Stream* stream, xla::LocalClient* client, - bool transfer_as_literal) - : manager_(stream, client, transfer_as_literal) {} +void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, + Tensor* dst_tensor, + const StatusCallback& done) { + // TODO(phawkins): replace this code with an asynchronous implementation. 
+ auto body = [&]() { + if (src_tensor.NumElements() == 0) { + return Status::OK(); + } + XlaTensor* xla_src = XlaTensor::FromTensor(&src_tensor); + XlaTensor* xla_dst = XlaTensor::FromTensor(dst_tensor); + CHECK(xla_src && xla_dst) + << "Missing destination tensor for device-to-device copy"; + if (!xla_dst->has_shaped_buffer()) { + TensorShape shape = + shape_representation_fn_(src_tensor.shape(), src_tensor.dtype()); + TF_RETURN_IF_ERROR( + xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_, + stream_->parent()->device_ordinal())); + } + TF_RETURN_IF_ERROR( + xla_dst->shaped_buffer().buffers().ForEachMutableElementWithStatus( + [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { + const se::DeviceMemoryBase& from_buffer = + xla_src->shaped_buffer().buffers().element(index); + CHECK_EQ(buffer->size(), from_buffer.size()); + if (!stream_->parent()->SynchronousMemcpy(buffer, from_buffer, + buffer->size())) { + return errors::Internal("Device to device memcpy failed"); + } + return Status::OK(); + })); + return Status::OK(); + }; + done(body()); +} + +XlaDeviceContext::XlaDeviceContext( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn) + : manager_(stream, client, transfer_as_literal, + std::move(shape_representation_fn)) {} void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, @@ -190,4 +261,10 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, done); } +void XlaDeviceContext::CopyDeviceTensorToDevice(const Tensor& src_tensor, + Tensor* dst_tensor, + const StatusCallback& done) { + manager_.CopyDeviceTensorToDevice(src_tensor, dst_tensor, done); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index d7f5f1d208989256f8043d2e6d93cf9bd89333b2..ee346e5653bbf9f393df202572c2150b4989506f 100644 --- 
a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/framework/allocator.h" @@ -45,14 +46,19 @@ class XlaDeviceAllocator : public Allocator { // Helper class for managing data transfers between host and XLA devices. class XlaTransferManager { public: - explicit XlaTransferManager(se::Stream* stream, xla::LocalClient* client, - bool transfer_as_literal); + explicit XlaTransferManager( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const; void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done); + + void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, + const StatusCallback& done); + se::Stream* stream() const { return stream_; } private: @@ -69,7 +75,8 @@ class XlaTransferManager { // Transfer manager, for marshalling data to and from the device. xla::TransferManager* transfer_manager_; // True if we must use XLA's TransferManager for correct device transfers. - bool transfer_as_literal_; + const bool transfer_as_literal_; + XlaCompiler::ShapeRepresentationFn shape_representation_fn_; }; // DeviceContext for operators assigned to XlaDevice devices. The @@ -77,8 +84,9 @@ class XlaTransferManager { // wraps the methods in XlaTransferManager. 
class XlaDeviceContext : public DeviceContext { public: - explicit XlaDeviceContext(se::Stream* stream, xla::LocalClient* client, - bool transfer_as_literal); + explicit XlaDeviceContext( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, @@ -86,6 +94,9 @@ class XlaDeviceContext : public DeviceContext { void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; + void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, + const StatusCallback& done); + se::Stream* stream() const override { return manager_.stream(); } private: diff --git a/tensorflow/compiler/jit/xla_device_ops.cc b/tensorflow/compiler/jit/xla_device_ops.cc index f68dba6b6a26c0c289fd8457ad143d62e5fb9a69..5ecb1afa7bcec910ca843ccd3a782745f2bb6ca8 100644 --- a/tensorflow/compiler/jit/xla_device_ops.cc +++ b/tensorflow/compiler/jit/xla_device_ops.cc @@ -15,7 +15,10 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device_ops.h" +#include + #include "tensorflow/compiler/jit/xla_device_context.h" +#include "tensorflow/compiler/jit/xla_tensor.h" namespace tensorflow { @@ -26,4 +29,82 @@ void XlaDeviceDummyOp::Compute(OpKernelContext* ctx) { << type_string() << " on an XLA device. 
This should never happen."; } +XlaAssignVariableOp::XlaAssignVariableOp(OpKernelConstruction* c) + : AsyncOpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_)); +} + +void XlaAssignVariableOp::ComputeAsync(OpKernelContext* context, + DoneCallback done) { + OP_REQUIRES_ASYNC(context, dtype_ == context->input(1).dtype(), + errors::InvalidArgument( + "Variable and value dtypes don't match; respectively, ", + dtype_, " and ", context->input(1).dtype()), + done); + Var* variable = nullptr; + OP_REQUIRES_OK_ASYNC( + context, + LookupOrCreateResource<Var>( + context, HandleFromInput(context, 0), &variable, + [this, context](Var** ptr) { + *ptr = new Var(dtype_); + PersistentTensor unused; + Tensor* tmp; + AllocatorAttributes attr; + TF_RETURN_IF_ERROR(context->allocate_persistent( + dtype_, context->input(1).shape(), &unused, &tmp, attr)); + *(*ptr)->tensor() = *tmp; + return Status::OK(); + }), + done); + core::ScopedUnref s(variable); + + OP_REQUIRES_ASYNC(context, variable->tensor()->dtype() == dtype_, + errors::InvalidArgument( + "Trying to assign variable with wrong dtype. Expected ", + DataTypeString(variable->tensor()->dtype()), " got ", + DataTypeString(dtype_)), + done); + + const Tensor& value = context->input(1); + AllocatorAttributes attr; + + // Copying is unnecessary if we are the last user of the value tensor, we can + // just adopt the input tensor's buffer instead. + std::unique_ptr<Tensor> input_alias = context->forward_input( + 1, /*output_index=*/OpKernelContext::Params::kNoReservation, dtype_, + value.shape(), DEVICE_MEMORY, attr); + mutex_lock ml(*variable->mu()); + variable->is_initialized = true; + if (input_alias) { + *variable->tensor() = *input_alias; + done(); + return; + } + + // Need to copy, but maybe we can re-use variable's buffer?
+ if (!XlaTensor::RefCountIsOne(*variable->tensor()) || + !variable->tensor()->shape().IsSameSize(value.shape())) { + // Copy to new buffer + PersistentTensor unused; + Tensor* tmp; + OP_REQUIRES_OK_ASYNC(context, + context->allocate_persistent(dtype_, value.shape(), + &unused, &tmp, attr), + done); + *variable->tensor() = *tmp; + } + + XlaDeviceContext* device_context = + static_cast<XlaDeviceContext*>(context->op_device_context()); + + variable->Ref(); + device_context->CopyDeviceTensorToDevice( + value, variable->tensor(), [context, variable, done](Status status) { + variable->Unref(); + context->SetStatus(status); + done(); + }); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 498d25cf566a91f68e5eb1ac312e17900471aeca..11e45d2823da2b623bd3cd45f7147686b05fdb2f 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -23,16 +23,19 @@ limitations under the License. #include "tensorflow/core/kernels/cast_op.h" #include "tensorflow/core/kernels/constant_op.h" #include "tensorflow/core/kernels/control_flow_ops.h" +#include "tensorflow/core/kernels/identity_n_op.h" #include "tensorflow/core/kernels/identity_op.h" #include "tensorflow/core/kernels/no_op.h" +#include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/sendrecv_ops.h" +#include "tensorflow/core/kernels/shape_ops.h" #include "tensorflow/core/kernels/variable_ops.h" namespace tensorflow { // Dummy OpKernel, used for kernels assigned to an XLA device that should be // compiled. Should never be called at runtime since such ops should be -// rewritten to a _XlaLaunch op. If it is called, it means the placer placed an +// rewritten to a XlaLaunch op. If it is called, it means the placer placed an // operator on an XLA device but the compiler did not compile it.
class XlaDeviceDummyOp : public OpKernel { public: @@ -40,8 +43,17 @@ class XlaDeviceDummyOp : public OpKernel { void Compute(OpKernelContext* ctx) override; }; +class XlaAssignVariableOp : public AsyncOpKernel { + public: + explicit XlaAssignVariableOp(OpKernelConstruction* c); + void ComputeAsync(OpKernelContext* context, DoneCallback done) override; + + private: + DataType dtype_; +}; + #define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \ - REGISTER_KERNEL_BUILDER(Name("_XlaLaunch") \ + REGISTER_KERNEL_BUILDER(Name("XlaLaunch") \ .Device(DEVICE) \ .HostMemory("constants") \ .HostMemory("resources"), \ @@ -63,13 +75,77 @@ class XlaDeviceDummyOp : public OpKernel { ConstantOp); \ REGISTER_KERNEL_BUILDER( \ Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("IdentityN").Device(DEVICE).TypeConstraint("T", TYPES), \ + IdentityNOp); \ REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), PlaceholderOp); \ REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \ PlaceholderOp); \ \ REGISTER_KERNEL_BUILDER( \ Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), \ - ResourceHandleOp); + ResourceHandleOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \ + ReadVariableOp); \ + REGISTER_KERNEL_BUILDER(Name("Shape") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint<int32>("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeOp<int32>); \ + REGISTER_KERNEL_BUILDER(Name("Shape") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint<int64>("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeOp<int64>); \ + REGISTER_KERNEL_BUILDER(Name("ShapeN") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint<int32>("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeNOp<int32>); \ + REGISTER_KERNEL_BUILDER(Name("ShapeN") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint<int64>("out_type") \ + 
.TypeConstraint("T", TYPES), \ + ShapeNOp<int64>); \ +
REGISTER_KERNEL_BUILDER(Name("Size") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint<int32>("out_type") \ + .TypeConstraint("T", TYPES), \ + SizeOp<int32>); \ + REGISTER_KERNEL_BUILDER(Name("Size") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint<int64>("out_type") \ + .TypeConstraint("T", TYPES), \ + SizeOp<int64>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Rank").Device(DEVICE).HostMemory("output").TypeConstraint("T", \ + TYPES), \ + RankOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"), \ + XlaAssignVariableOp); \ + REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE), \ + ControlTriggerOp); \ + REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \ + SwitchOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \ + REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \ + REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \ + REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \ + NextIterationOp); \ + REGISTER_KERNEL_BUILDER(Name("LoopCond") \ + .Device(DEVICE) \ + .HostMemory("input") \ + .HostMemory("output"), \ + LoopCondOp); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc new file mode 100644 index 0000000000000000000000000000000000000000..74257b09a808a39454eace3b1a9bf57a2e071360 --- /dev/null +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -0,0 +1,328 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_fusion_optimizer.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" + +namespace tensorflow { + +// Is 'node' an operator that consumes only the shape of its input, not the +// data itself? +static bool IsShapeConsumerOp(const Node& node) { + return node.type_string() == "Shape" || node.type_string() == "ShapeN" || + node.type_string() == "Rank" || node.type_string() == "Size"; +} + +// Returns true if the op can be decomposed into XLA ops for which +// there are fusable elemental implementations. 
+bool IsXlaFusable(const NodeDef& node) { + static const std::unordered_set<string>* elementwise_ops = + new std::unordered_set<string>( + {// tf2xla/kernels/aggregate_ops.cc + "AddN", + // tf2xla/kernels/binary_ops.cc + "Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv", + "FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift", + "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv", + "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference", + "TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater", + "GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", + "SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual", + // tf2xla/kernels/unary_ops.cc + "ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin", + "Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", + "Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", + "Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round", + "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt", + "Square", "Tan", "Tanh", "Real", "Imag", + // tf2xla/kernels/bcast_ops.cc + "BroadcastArgs", "BroadcastGradientArgs", + // tf2xla/kernels/bias_ops.cc + "BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/, + // tf2xla/kernels/cast_op.cc + "Cast", + // tf2xla/kernels/concat_op.cc + "Concat", "ConcatV2", "ConcatOffset", + // tf2xla/kernels/const_op.cc + "Const", + // tf2xla/kernels/elu_op.cc + "Elu", "EluGrad", "Selu", "SeluGrad", + // tf2xla/kernels/fill_op.cc + "Fill", + // tf2xla/kernels/identity_op.cc + "Identity", "IdentityN", "PreventGradient", + "StopGradient", /*"Snapshot",*/ + // tf2xla/kernels/index_ops.cc + "ArgMax", "ArgMin", + // tf2xla/kernels/mirror_pad_op.cc + "MirrorPad", + // tf2xla/kernels/one_hot_op.cc + "OneHot", + // tf2xla/kernels/pack_op.cc + "Pack", + // tf2xla/kernels/pad_op.cc + "Pad", "PadV2", + // tf2xla/kernels/relu_op.cc + "Relu", "Relu6", "ReluGrad", "Relu6Grad", + // tf2xla/kernels/reshape_op.cc + 
"Reshape", + //
tf2xla/kernels/reverse_op.cc + "Reverse", "ReverseV2", + // tf2xla/kernels/reverse_sequence_op.cc + "ReverseSequence", + // tf2xla/kernels/shape_op.cc + "Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze", + "ZerosLike", "OnesLike", + // tf2xla/kernels/slice_op.cc + "Slice", + // tf2xla/kernels/split_op.cc + "Split", "SplitV", + // tf2xla/kernels/strided_slice_op.cc + "StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign", + // tf2xla/kernels/tile_ops.cc + "Tile", + // tf2xla/kernels/transpose_op.cc + "Transpose", "InvertPermutation", + // tf2xla/kernels/unpack_op.cc + "Unpack"}); + + return elementwise_ops->count(node.op()) > 0; +} + +Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, + const grappler::GrapplerItem& item, + GraphDef* output) { + VLOG(2) << "Here at fusion optimizer"; + + // TODO(hpucha): Implement encapsulation and replacing with XlaLaunch op. + // Once that happens, the expected interaction between this optimizer and when + // the global_jit_level is set is as follows: Fusion optimizer will replace + // appropriate fusion clusters with XlaLaunch nodes. The remaining graph can + // be further compiled where possible via mark_for_compilation_pass. Note that + // this might lead to inefficient clustering, and it is best to use either the + // fusion optimizer or the global_jit flag, and not combine the two. + + // Create a Graph out of GraphDef. This is required currently because the + // helpers around clustering, encapsulation etc work on graphs. 
+ FunctionLibraryDefinition function_library(OpRegistry::Global(), + item.graph.library()); + Graph graph(function_library); + ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); + shape_refiner.set_require_shape_inference_fns(false); + shape_refiner.set_disable_constant_propagation(true); + ImportGraphDefOptions options; + // Graph optimization happens at the late stage of graph execution, when + // colocation constraints are already validated previously and the device + // placement of nodes has also completed, so there is no need to validate + // colocation constraints again. + options.validate_colocation_constraints = false; + options.validate_shape = false; + TF_RETURN_IF_ERROR( + ImportGraphDef(options, item.graph, &graph, &shape_refiner)); + + // Collect nodes that can be fused via XLA, while ignoring those that + // explicitly ask for XLA: (*) nodes that are marked to be compiled + // explicitly. (*) nodes assigned to XLA device. + OrderedNodeSet compilation_candidates; + for (Node* node : graph.op_nodes()) { + // If there is a _XlaCompile annotation, ignore the node if it is + // true. Nodes are marked with this attr via experimental_jit_scope, and + // will be handled by the mark_for_compilation pass. + bool compile = false; + Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); + if (status.ok() && compile) { + continue; + } + // If there is already a _XlaCluster annotation, ignore the node. Nodes are + // marked with this attr to indicate they are already part of a cluster and + // hence ignored. + status = GetNodeAttr(node->attrs(), kXlaClusterAttr, &compile); + if (status.ok()) { + continue; + } + + // If there is an explicit XLA device placement, ignore the node. + DeviceType device_type(""); + TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type)); + if (device_type.type_string().find("XLA") != string::npos) continue; + + // Assume all fusable ops are registered. 
+ // TODO(hpucha): Check for registration if possible. + if (!IsXlaFusable(node->def())) { + continue; + } + + // XLA does not offer guaranteed aliasing between the input and output of + // the XLA cluster so it can't implement the forward-tensor-ref semantic. + // Leave such nodes out of XLA clusters. + if (HasForwardedRefInput(*node)) { + continue; + } + + compilation_candidates.insert(node); + } + + if (compilation_candidates.empty()) { + VLOG(2) << "No compilable candidates"; + *output = item.graph; + return Status::OK(); + } + + GraphCycles cycles; + TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles)); + + // TODO(hpucha): Make clustering more robust. There are two known issues that + // we need to mitigate: (a) Non-resource variables can cause deadlocks + // when clustering changes order of execution. See b/77263461 for a specific + // example. (b) Queue operations can also cause deadlocks. See b/77261498 for + // example. + + struct Cluster { + // Identifies the node that represents this cluster in the cycle detection + // graph. + int representative = -1; + }; + + // Each compilation candidate belongs to a cluster. The cluster's + // representative names the node in the 'cycles' graph that represents the + // cluster. + std::vector<UnionFind<Cluster>> clusters(graph.num_node_ids()); + std::deque<UnionFind<Cluster>*> worklist; + for (Node* node : compilation_candidates) { + Cluster& cluster = clusters[node->id()].Get(); + cluster.representative = node->id(); + worklist.push_back(&clusters[node->id()]); + } + + // Repeatedly contract edges between clusters that are on the same device, + // provided the contraction would not create a cycle. This is a simplified + // version of the clustering in mark_for_compilation_pass that also deals with + // nodes that are explicitly tagged to be compiled/clustered.
+ while (!worklist.empty()) { + int from = worklist.front()->Get().representative; + worklist.pop_front(); + + Node* node_from = graph.FindNodeId(from); + if (node_from->IsControlFlow()) { + // Control flow nodes aren't compilation candidates and should never + // appear. + return errors::Internal( + "Found control flow node in clustering worklist: ", + node_from->type_string()); + } + for (int to : cycles.Successors(from)) { + if (to >= graph.num_node_ids()) { + // Node is a "frame" node that is present only in the cycle detection + // graph. No clustering is possible. + continue; + } + Node* node_to = graph.FindNodeId(to); + if (compilation_candidates.find(node_to) == + compilation_candidates.cend()) { + continue; + } + + // Do not cluster across devices. + if (node_from->def().device() != node_to->def().device()) { + VLOG(2) << "Devices " << node_from->def().device() << " " + << node_to->def().device(); + VLOG(2) << "Device names " << node_from->assigned_device_name() << " " + << node_to->assigned_device_name(); + continue; + } + + // Ops that consume shapes cannot be the root of a cluster. This is an + // optimization. + if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) { + continue; + } + + // If contracting the edge would create a cycle, bail out. + // However, just because we can't merge the clusters now does not mean + // we won't be able to merge them in the future. + // e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge + // 1->3. But if we first contract 1->2 then we can later contract 1->3. + if (!cycles.ContractEdge(from, to)) continue; + + // Merge the clusters. ContractEdge uses 'from' as the number of the + // merged node, so make sure 'from' is the chosen representative. + clusters[from].Merge(&clusters[to]); + + worklist.push_back(&clusters[from]); + break; + } + } + + // Count the number of non-trivial elements in each cluster. 
+ std::vector<int> effective_cluster_sizes(graph.num_node_ids()); + for (const Node* n : compilation_candidates) { + int cluster = clusters[n->id()].Get().representative; + // Identity nodes will be removed if the node gets marked for compilation. + // Therefore we don't want to count them towards the effective cluster size. + if (n->def().op() != "Identity") { + effective_cluster_sizes[cluster]++; + } + } + + const int min_cluster_size = 2; + int num_clusters = 0; + for (auto size : effective_cluster_sizes) { + if (size >= min_cluster_size) { + VLOG(3) << "Cluster " << num_clusters << " " << size; + num_clusters++; + } + } + + // Names for each cluster. + std::unordered_map<int, string> cluster_names; + // Sequence number generator to ensure clusters have unique names. + static std::atomic<int64> cluster_sequence_num; + + for (Node* n : compilation_candidates) { + int cluster = clusters[n->id()].Get().representative; + + // Compile if this is a cluster of >= min_cluster_size compilable operators. + if (effective_cluster_sizes[cluster] >= min_cluster_size) { + string& name = cluster_names[cluster]; + + if (name.empty()) { + name = strings::StrCat("cluster_", cluster_sequence_num++); + } + n->AddAttr(kXlaClusterAttr, name); + VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; + } + } + + graph.ToGraphDef(output); + return Status::OK(); +} + +REGISTER_GRAPH_OPTIMIZER_AS(XlaFusionOptimizer, "xla-fusion"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.h b/tensorflow/compiler/jit/xla_fusion_optimizer.h new file mode 100644 index 0000000000000000000000000000000000000000..3d2309e782d38725f8db025fbfda0bf0f63d18be --- /dev/null +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { + +// Optimizes graphs by fusing ops where possible, resulting in more efficient +// execution. +class XlaFusionOptimizer : public grappler::CustomGraphOptimizer { + public: + XlaFusionOptimizer() {} + ~XlaFusionOptimizer() override {} + + Status Init( + const RewriterConfig_CustomGraphOptimizer* config = nullptr) override { + return Status::OK(); + } + + string name() const override { return "xla-fusion"; }; + + Status Optimize(grappler::Cluster* cluster, + const grappler::GrapplerItem& item, + GraphDef* output) override; + + void Feedback(grappler::Cluster* cluster, const grappler::GrapplerItem& item, + const GraphDef& optimize_output, double result) override { + // Nothing to do for XlaFusionOptimizer. + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5736760a878dc857a8558093054d0adc0f727398 --- /dev/null +++ b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc @@ -0,0 +1,183 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_fusion_optimizer.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +REGISTER_OP("UncompilableNullary").Output("o: float"); +REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float"); + +class XlaFusionOptimizerTest : public grappler::GrapplerTest { + protected: + std::unordered_map<string, string> GetClusters(const GraphDef& graph) { + std::unordered_map<string, string> ids; + for (const NodeDef& node : graph.node()) { + string cluster; + if (GetNodeAttr(AttrSlice(node), kXlaClusterAttr, &cluster).ok()) { + CHECK(!cluster.empty()); + ids[node.name()] = cluster; + } + } + return ids; + } +}; + +TEST_F(XlaFusionOptimizerTest, Chains) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = + ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C")); + Node* d = + ops::UnaryOp("UncompilableUnary", c, 
builder.opts().WithName("D")); + Node* e =
ops::UnaryOp("Relu", d, builder.opts().WithName("E")); + ops::UnaryOp("Relu", e, builder.opts().WithName("F")); + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_EQ(4, clusters.size()); + EXPECT_EQ(clusters["B"], clusters["C"]); + EXPECT_EQ(clusters["E"], clusters["F"]); + EXPECT_NE(clusters["B"], clusters["E"]); + EXPECT_TRUE(clusters.find("A") == clusters.cend()); + EXPECT_TRUE(clusters.find("D") == clusters.cend()); +} + +TEST_F(XlaFusionOptimizerTest, FusableOps) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp( + "Placeholder", + builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT)); + Node* b = ops::SourceOp( + "Placeholder", + builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT)); + + Node* c = ops::BinaryOp("Add", a, b, builder.opts().WithName("C")); + ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D")); + ops::UnaryOp("Abs", c, builder.opts().WithName("E")); + + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_EQ(2, clusters.size()); + EXPECT_EQ(clusters["C"], clusters["E"]); + EXPECT_TRUE(clusters.find("D") == clusters.cend()); +} + +TEST_F(XlaFusionOptimizerTest, IgnoreExplicitXLAAttrs) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp( + "Placeholder", + builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT)); + Node* b = ops::SourceOp( + "Placeholder", + builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT)); + + Node* c = ops::BinaryOp( + "Add", 
a, b, + builder.opts().WithName("C").WithDevice("/device:XLA_CPU")); + ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D")); + Node* e = ops::UnaryOp("Abs", c, builder.opts().WithName("E")); + ops::UnaryOp("Cos", e, + builder.opts().WithName("F").WithAttr(kXlaCompileAttr, true)); + + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_TRUE(clusters.empty()); +} + +TEST_F(XlaFusionOptimizerTest, UncompilableCycles) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = + ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B")); + ops::BinaryOp("Mul", a, b, builder.opts().WithName("C")); + + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_TRUE(clusters.empty()); +} + +TEST_F(XlaFusionOptimizerTest, CompilableCycles) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + ops::BinaryOp("Mul", a, b, builder.opts().WithName("C")); + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_EQ(3, clusters.size()); + EXPECT_EQ(clusters["A"], 
clusters["B"]); + EXPECT_EQ(clusters["A"], clusters["C"]); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index a8afbf9dcd736bb292b7c5f52c7cce2b47fb85b6..c0d86a28c7698c302e28bab972bb2f847cc00ca4 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -48,7 +48,9 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, Status status = XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, name_prefix, registration, - /*transfer_as_literal=*/false, &device); + /*transfer_as_literal=*/false, + /*shape_representation_fn=*/{}, + /*padded_shape_fn=*/{}, &device); if (!status.ok()) { // Treat failures as non-fatal; there might not be a GPU in the machine. VLOG(1) << "Failed to create XLA_GPU device: " << status; diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 9e098c46f422b436c722bb909dc58930ab7c0ef6..661187f4a873b03b8d013aa74cb6b6315bb4e2eb 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -51,7 +51,9 @@ Status XlaInterpreterDeviceFactory::CreateDevices( TF_RETURN_IF_ERROR(XlaDevice::Create("Interpreter", DEVICE_XLA_INTERPRETER, 0, DEVICE_INTERPRETER_XLA_JIT, options, name_prefix, registration, - /*transfer_as_literal=*/false, &device)); + /*transfer_as_literal=*/false, + /*shape_representation_fn=*/{}, + /*padded_shape_fn=*/{}, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 0223f97a032cf9efe56005248ce65d412e340b78..d0c7a9365125708b2af43f87c7617d8d84050a61 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -60,19 +60,22 @@ XlaAllocator::XlaAllocator(const 
se::Platform* platform, Allocator* wrapped) XlaAllocator::~XlaAllocator() {} -xla::StatusOr<se::DeviceMemoryBase> XlaAllocator::Allocate( +xla::StatusOr<xla::OwningDeviceMemory> XlaAllocator::Allocate( int device_ordinal, uint64 size, bool retry_on_failure) { - void* data = wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size); + AllocationAttributes attrs; + attrs.no_retry_on_failure = !retry_on_failure; + void* data = + wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs); if (data == nullptr) { return errors::ResourceExhausted("Out of memory while trying to allocate ", size, " bytes."); - } else { - return se::DeviceMemoryBase(data, size); } + return xla::OwningDeviceMemory(se::DeviceMemoryBase(data, size), + device_ordinal, this); } -Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase* mem) { - wrapped_->DeallocateRaw(mem->opaque()); +Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) { + wrapped_->DeallocateRaw(mem.opaque()); return Status::OK(); } @@ -192,11 +195,6 @@ void XlaComputationLaunchContext::PopulateOutputs( OP_REQUIRES_OK( ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); - if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) { - OP_REQUIRES_OK(ctx, xla_tensor->AllocateShapedBuffer( - const_tensor.dtype(), const_tensor.shape(), - client_, stream->parent()->device_ordinal())); - } Device* device = dynamic_cast<Device*>(ctx->device()); OP_REQUIRES(ctx, device != nullptr, @@ -238,7 +236,7 @@ void XlaComputationLaunchContext::PopulateOutputs( } else { Tensor output_tensor = XlaTensorBuffer::MakeTensor( ctx->expected_output_dtype(i), shape, buffer, allocator); - output.set_buffer(se::DeviceMemoryBase(nullptr, 0), {output_num}); + output.set_buffer(xla::OwningDeviceMemory(), {output_num}); ctx->set_output(i, output_tensor); } ++output_num; @@ -288,7 +286,7 @@ void XlaComputationLaunchContext::PopulateOutputs( } else { Tensor output_tensor = XlaTensorBuffer::MakeTensor( write.type, write.shape, buffer,
allocator); - output.set_buffer(se::DeviceMemoryBase(nullptr, 0), {output_num}); + output.set_buffer(xla::OwningDeviceMemory(), {output_num}); *variable->tensor() = output_tensor; } ++output_num; diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index a2431253f8c44bdd9b99a253f79bdb14722d7c72..4390701ccbd0bc3971413ddcd917c11019990087 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -22,6 +22,8 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/owning_device_memory.h" #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" @@ -50,9 +52,9 @@ class XlaAllocator : public xla::DeviceMemoryAllocator { public: XlaAllocator(const se::Platform* platform, Allocator* wrapped); ~XlaAllocator() override; - xla::StatusOr<se::DeviceMemoryBase> Allocate(int device_ordinal, uint64 size, - bool retry_on_failure) override; - Status Deallocate(int device_ordinal, se::DeviceMemoryBase* mem) override; + xla::StatusOr<xla::OwningDeviceMemory> Allocate( + int device_ordinal, uint64 size, bool retry_on_failure) override; + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; // The Tensorflow BFC allocator used on GPU allows host-side deallocation // before GPU execution takes place.
Tensorflow uses the ordering of the main diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc index 27813efc0bc0aecdbea2dfce5ca27ba704ea45e2..a45932403ec1760d6b985d5357fd6d84fbf257a2 100644 --- a/tensorflow/compiler/jit/xla_launch_util_test.cc +++ b/tensorflow/compiler/jit/xla_launch_util_test.cc @@ -36,9 +36,9 @@ void BM_ExtractSubBuffer(int iters, int depth, int fan_out) { for (int i = 0; i < iters; ++i) { // Extract a buffer from approximately the middle of the first level of the // tree. - tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer, - /*index=*/fan_out / 2, - /*allocator=*/nullptr) + (void)tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer, + /*index=*/fan_out / 2, + /*allocator=*/nullptr) .release(); } } diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index ce6456880bc1b3bc15ac0ef4bae35a83771098ef..3c44c4ae6df7f3e2d60d8933561c0c71888e8c3f 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -18,7 +18,7 @@ limitations under the License. 
namespace tensorflow { -/*static*/ XlaTensor* XlaTensor::FromTensor(Tensor* tensor) { +/*static*/ XlaTensor* XlaTensor::FromTensor(const Tensor* tensor) { if (tensor->NumElements() == 0) { return nullptr; } @@ -27,8 +27,8 @@ namespace tensorflow { return xla_tensor; } -/*static*/ const XlaTensor* XlaTensor::FromTensor(const Tensor* tensor) { - return FromTensor(const_cast<Tensor*>(tensor)); +/*static*/ bool XlaTensor::RefCountIsOne(const Tensor& tensor) { + return tensor.RefCountIsOne(); } /*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor( @@ -52,20 +52,24 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, client->backend().transfer_manager()->HostShapeToDeviceShape( on_host_shape); - xla::ShapedBuffer buffer(on_host_shape, on_device_shape, client->platform(), - device_ordinal); - for (auto& index_to_buffer : buffer.buffers()) { + xla::ScopedShapedBuffer shaped_buffer(on_host_shape, on_device_shape, + client->backend().memory_allocator(), + device_ordinal); + for (auto& index_to_buffer : shaped_buffer.buffers()) { xla::Shape subshape = xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first); uint64 size = client->backend().transfer_manager()->GetByteSizeRequirement(subshape); - TF_ASSIGN_OR_RETURN(index_to_buffer.second, + TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer, client->backend().memory_allocator()->Allocate( device_ordinal, size, /*retry_on_failure=*/false)); + // Move our buffer into shaped_buffer, which takes ownership of it.
+ index_to_buffer.second = buffer.Forget(); } - set_shaped_buffer(xla::ScopedShapedBuffer( - std::move(buffer), client->backend().memory_allocator())); + VLOG(4) << shaped_buffer.ToString(); + + set_shaped_buffer(std::move(shaped_buffer)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index 6b29c82ec11e39ad525663991e179443c2b6dca7..c54001a999998f45c0cdacd752ca4036f0792857 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -34,10 +34,9 @@ class XlaTensor { public: // Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast // fails. - static XlaTensor* FromTensor(Tensor* tensor); - // Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast - // fails. - static const XlaTensor* FromTensor(const Tensor* tensor); + static XlaTensor* FromTensor(const Tensor* tensor); + + static bool RefCountIsOne(const Tensor& tensor); // Create a DeviceMemoryBase from a Tensor. The Tensor can be an XlaTensor, in // which case the returned value is shaped_buffer()->root_buffer(), or a @@ -62,6 +61,10 @@ class XlaTensor { CHECK(has_shaped_buffer()); return *shaped_buffer_; } + xla::ShapedBuffer& shaped_buffer() { + CHECK(has_shaped_buffer()); + return *shaped_buffer_; + } // Mutates the XlaTensor to set the ShapedBuffer. 
void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) { shaped_buffer_ = diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 9791792f29ca05f4ece77cca6305ed05343d1d38..e6c92f9720e1285617280f60d1c5fea443c5ebef 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -42,7 +42,7 @@ py_library( "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform", "//tensorflow/python:random_seed", "//tensorflow/python:session", @@ -58,7 +58,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -72,7 +72,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -93,7 +93,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -111,7 +111,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:bitwise_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -120,6 +120,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "bucketize_op_test", + size = "small", + srcs = ["bucketize_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + 
"//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "categorical_op_test", size = "small", @@ -127,7 +140,7 @@ tf_xla_py_test( tags = ["optonly"], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], @@ -141,7 +154,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -156,7 +169,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -170,7 +183,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -184,7 +197,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:array_ops_gen", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:gradient_checker", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", @@ -196,9 +209,11 @@ tf_xla_py_test( name = "oom_test", size = "medium", srcs = ["oom_test.py"], + # TODO(b/80081500): Re-enable on GPU. Disabled on 2018-05-21. disabled_backends = [ "cpu", "cpu_ondemand", + "gpu", ], tags = [ # Allocates very large amounts of memory and does not work under TSAN. 
@@ -209,7 +224,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:array_ops_gen", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:gradient_checker", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", @@ -225,7 +240,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -241,7 +256,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -263,7 +278,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -291,7 +306,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -307,7 +322,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -326,7 +341,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:layers", "//tensorflow/python:math_ops", "//tensorflow/python:nn", @@ -346,7 +361,7 @@ tf_xla_py_test( "//tensorflow/contrib/signal:signal_py", 
"//tensorflow/python:array_ops", "//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:spectral_ops", ], @@ -360,7 +375,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -372,7 +387,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -388,7 +403,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -403,12 +418,27 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:image_ops", "//tensorflow/python:platform_test", ], ) +tf_xla_py_test( + name = "listdiff_op_test", + size = "small", + srcs = ["listdiff_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform_test", + "@six_archive//:six", + ], +) + tf_xla_py_test( name = "lrn_ops_test", size = "medium", @@ -416,7 +446,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -431,7 +461,7 @@ 
tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -443,7 +473,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -457,7 +487,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -470,7 +500,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -483,7 +513,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -498,7 +528,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -515,7 +545,9 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], @@ -530,7 +562,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:errors", - 
"//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -546,7 +578,7 @@ tf_xla_py_test( "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:array_ops", "//tensorflow/python:errors", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -559,7 +591,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", ], ) @@ -571,7 +603,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -583,7 +615,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -598,7 +630,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -611,7 +643,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:platform_test", @@ -626,7 +658,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -642,7 +674,7 @@ 
tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -655,7 +687,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/contrib/stateless", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -669,7 +701,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -688,7 +720,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -701,7 +733,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -715,7 +747,7 @@ tf_xla_py_test( srcs = ["fused_batchnorm_test.py"], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn", @@ -734,7 +766,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -753,7 +785,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:array_ops", - 
"//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:training", ], @@ -768,7 +800,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -780,7 +812,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -793,21 +825,34 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) -cuda_py_test( +tf_xla_py_test( name = "xla_device_test", size = "small", srcs = ["xla_device_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_test( + name = "xla_device_gpu_test", + size = "small", + srcs = ["xla_device_gpu_test.py"], additional_deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", ], ) @@ -824,11 +869,22 @@ cuda_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:gradients", - "//tensorflow/python:layers", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", + ], +) + +cuda_py_test( + name = "dense_layer_test", + size = "small", + srcs = ["dense_layer_test.py"], + additional_deps = [ + 
"//tensorflow/contrib/compiler:compiler_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:layers", "//tensorflow/python:variables", ], ) @@ -872,7 +928,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:variables", @@ -887,7 +943,7 @@ cuda_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:gradients", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", @@ -925,7 +981,7 @@ tf_xla_py_test( srcs = ["fake_quant_ops_test.py"], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -937,7 +993,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py index ec547e16cd9c91a1e25bc963b9a3cafddf7326cd..9d3a889b1f54c813e881bb03b5275f809af1b3c8 100644 --- a/tensorflow/compiler/tests/argminmax_test.py +++ b/tensorflow/compiler/tests/argminmax_test.py @@ -29,51 +29,70 @@ from tensorflow.python.platform import test class ArgMinMaxTest(xla_test.XLATestCase): - def _assertOpOutputMatchesExpected(self, op, inp, expected): - """Verifies that 'op' produces 'expected' when fed input 'inp' . + def _assertOpOutputMatchesExpected(self, op, axis, output_type, op_input, + expected): + """Verifies that 'op' produces 'expected' when fed input 'op_input' . 
Args: - op: operator to test - inp: numpy input array to use as input to 'op'. + op: argmin or argmax operator to test. + axis: integer axis to reduce across. + output_type: numpy datatype of the output to produce. + op_input: numpy input array to use as input to 'op'. expected: numpy array representing the expected output of 'op'. """ with self.test_session() as session: with self.test_scope(): pinp = array_ops.placeholder( - dtypes.as_dtype(inp.dtype), inp.shape, name="a") - output = op(pinp) - result = session.run(output, {pinp: inp}) + dtypes.as_dtype(op_input.dtype), op_input.shape, name="a") + output = op(pinp, axis=axis, output_type=output_type) + result = session.run(output, {pinp: op_input}) self.assertAllEqual(result, expected) def testArgMinMax(self): # Complex numbers do not support argmin/argmax. minmax_types = set(self.numeric_types) - set(self.complex_types) for dtype in minmax_types: - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32), - np.array([1, 10, 27, 3, 3, 4], dtype=dtype), - expected=np.int32(2)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32), - np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), - expected=np.array([0, 1, 0], dtype=np.int32)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmax(x, axis=1, output_type=dtypes.int32), - np.array([[4, 1], [3, 2]], dtype=dtype), - expected=np.array([0, 0], dtype=np.int32)) + # output_type is a numpy data type that is used to specify the desired + # output type of the op as well as to convert the Python number to the + # array scalar of the type. 
+ for output_type in self.int_types: + self._assertOpOutputMatchesExpected( + math_ops.argmax, + axis=0, + output_type=output_type, + op_input=np.array([1, 10, 27, 3, 3, 4], dtype=dtype), + expected=output_type(2)) + self._assertOpOutputMatchesExpected( + math_ops.argmax, + axis=0, + output_type=output_type, + op_input=np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), + expected=np.array([0, 1, 0], dtype=output_type)) + self._assertOpOutputMatchesExpected( + math_ops.argmax, + axis=1, + output_type=output_type, + op_input=np.array([[4, 1], [3, 2]], dtype=dtype), + expected=np.array([0, 0], dtype=output_type)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmin(x, axis=0, output_type=dtypes.int32), - np.array([3, 10, 27, 3, 2, 4], dtype=dtype), - expected=np.int32(4)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmin(x, axis=0, output_type=dtypes.int32), - np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), - expected=np.array([1, 0, 1], dtype=np.int32)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmin(x, axis=1, output_type=dtypes.int32), - np.array([[4, 1], [3, 2]], dtype=dtype), - expected=np.array([1, 1], dtype=np.int32)) + self._assertOpOutputMatchesExpected( + math_ops.argmin, + axis=0, + output_type=output_type, + op_input=np.array([3, 10, 27, 3, 2, 4], dtype=dtype), + expected=output_type(4)) + self._assertOpOutputMatchesExpected( + math_ops.argmin, + axis=0, + output_type=output_type, + op_input=np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), + expected=np.array([1, 0, 1], dtype=output_type)) + self._assertOpOutputMatchesExpected( + math_ops.argmin, + axis=1, + output_type=output_type, + op_input=np.array([[4, 1], [3, 2]], dtype=dtype), + expected=np.array([1, 1], dtype=output_type)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/bucketize_op_test.py b/tensorflow/compiler/tests/bucketize_op_test.py new file mode 100644 index 
0000000000000000000000000000000000000000..fde9759a1c209844caac99d5f303cd3e406e5370 --- /dev/null +++ b/tensorflow/compiler/tests/bucketize_op_test.py @@ -0,0 +1,78 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for bucketize_op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class BucketizationOpTest(XLATestCase): + + def testInt(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11]) + expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4] + self.assertAllEqual(expected_out, + sess.run(op, {p: [-5, 0, 2, 3, 5, 8, 10, 11, 12]})) + + def testFloat(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.float32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0., 3., 8., 11.]) + expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4] + self.assertAllEqual( + expected_out, + sess.run(op, {p: [-5., 0., 2., 3., 5., 8., 10., 
11., 12.]})) + + def test2DInput(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.float32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11]) + expected_out = [[0, 1, 1, 2, 2], [3, 3, 4, 4, 1]] + self.assertAllEqual( + expected_out, sess.run(op, + {p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]})) + + def testInvalidBoundariesOrder(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0, 8, 3, 11]) + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "Expected sorted boundaries"): + sess.run(op, {p: [-5, 0]}) + + def testBoundariesNotList(self): + with self.test_session(): + with self.assertRaisesRegexp(TypeError, "Expected list.*"): + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + math_ops._bucketize(p, boundaries=0) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..865f60ccab46ec6829e49409508303052944e13b --- /dev/null +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -0,0 +1,135 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for DenseLayer JIT compilation on the CPU and GPU devices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np + +from tensorflow.contrib.compiler import jit +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.layers import layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + +jit_scope = jit.experimental_jit_scope + + +def GetRunMetadataLabels(run_metadata): + """Returns all labels in run_metadata.""" + labels = [] + for dev_stats in run_metadata.step_stats.dev_stats: + for node_stats in dev_stats.node_stats: + labels.append(node_stats.timeline_label) + return labels + + +def InLabels(labels, substr): + """Returns true iff one of the labels contains substr.""" + return any([substr in x for x in labels]) + + +def XlaLaunchOpCount(labels): + """Count how many XlaLaunch labels are present.""" + return sum("XlaLaunch(" in x for x in labels) + + +class DenseLayerTest(test.TestCase): + + def testDenseLayerAutoJit(self): + """Tests dense layer compilation in auto-jit mode. + + Dense layer should be compiled into a single XlaLaunch op in auto-jit mode. 
+ """ + + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit") + config = config_pb2.ConfigProto() + config.graph_options.optimizer_options.global_jit_level = ( + config_pb2.OptimizerOptions.ON_1) + + with self.test_session(config=config) as sess: + x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) + y = layers.dense(x, 3) + + sess.run(variables.initialize_all_variables()) + run_metadata = config_pb2.RunMetadata() + sess.run( + y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + + labels = GetRunMetadataLabels(run_metadata) + self.assertEqual(1, XlaLaunchOpCount(labels)) + self.assertFalse(InLabels(labels, "ListDiff")) + + def testDenseLayerJitScopeDefinedShape(self): + """Tests that the dense layer node is properly compiled in jit scope. + + Dense layer with static shape input tensor should be compiled into a single + XlaLaunch op by XLA. + """ + + with self.test_session() as sess: + x = array_ops.placeholder(shape=[2, 2, 3], dtype=np.float32) + with jit_scope(): + y = layers.dense(x, 3) + + sess.run(variables.initialize_all_variables()) + run_metadata = config_pb2.RunMetadata() + sess.run( + y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + + labels = GetRunMetadataLabels(run_metadata) + self.assertEqual(1, XlaLaunchOpCount(labels)) + # No need to check whether ListDiff is compiled or not because ListDiff op + # is not used when input tensor shape is fully defined. + + def testDenseLayerJitScopeUndefinedShape(self): + """Tests that the dense layer node is properly compiled in jit scope. + + Dense layer uses shape op to get shape of input tensor if its shape is not + fully defined. XLA does not cluster shape op with other operators. 
But in + experimental_jit_scope, XLA is forced to compile shape op into its own + cluster, causing dense layer to be split into TWO XlaLaunch ops. + """ + + with self.test_session() as sess: + x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) + with jit_scope(): + y = layers.dense(x, 3) + + sess.run(variables.initialize_all_variables()) + run_metadata = config_pb2.RunMetadata() + sess.run( + y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + + labels = GetRunMetadataLabels(run_metadata) + self.assertEqual(2, XlaLaunchOpCount(labels)) + self.assertFalse(InLabels(labels, "ListDiff")) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index 0a0d335ca76dd7ec7ca3b12f9e8a83b596daa07e..03d96a2cd8ab22a472a67f092e36224820405fa8 100644 --- a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -153,7 +153,7 @@ class DepthwiseConv2DTest(XLATestCase): dtype=data_type).reshape(filter_in_sizes) with self.test_session() as sess: if data_type == np.float32: - tolerance = 1e-5 + tolerance = 1e-4 else: self.assertEqual(data_type, np.float64) tolerance = 1e-8 @@ -339,7 +339,7 @@ class DepthwiseConv2DTest(XLATestCase): gpu_value = _GetVal(use_xla=True) cpu_value = _GetVal(use_xla=False) - self.assertAllClose(cpu_value, gpu_value, rtol=1e-4, atol=1e-4) + self.assertAllClose(cpu_value, gpu_value, rtol=1e-3, atol=1e-3) def testDepthwiseConv2DInputGradCompare(self): for index, (input_size, filter_size, output_size, stride, diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 5ab1585f8c6e07d6e3f0f40c99840b176492e523..fceb61ef879ed53a09954513ad9487263de0fe0a 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ 
b/tensorflow/compiler/tests/eager_test.py @@ -117,6 +117,15 @@ class EagerTest(XLATestCase): v.assign_add(2.0) self.assertEqual(3.0, v.numpy()) + def testReadAssignRead(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(1.0) + val1 = v.read_value() + v.assign_add(2.0) + val2 = v.read_value() + self.assertEqual(1.0, val1.numpy()) + self.assertEqual(3.0, val2.numpy()) + def testGradient(self): def f(x): return x @@ -136,6 +145,92 @@ class EagerTest(XLATestCase): grads = backprop.implicit_grad(f)() self.assertEqual(2., grads[0][0].numpy()) + def testMultipleVariableReads(self): + # This test makes sure consecutive variable reads don't copy + # the underlying memory. + with self.test_scope(): + # Create a 128MiB variable + var = resource_variable_ops.ResourceVariable( + array_ops.ones([32, 1024, 1024])) + + # Read the same variable 100 times. If the underlying tensor + # is not copied, this is a trivial operation. If it is copied, + # this will eat over 13GB and OOM. + values = [] + for _ in range(100): + values.append(var.value()) + + # The shape, shape_n, size, and rank are tested here because their + # execution kernels (as opposed to compilation-only tf2xla kernels) + # are distinct from tf2xla kernels. 
+ + def testShape(self): + def const(value): + return array_ops.shape( + constant_op.constant(value)).numpy() + + def ones(value): + return array_ops.shape( + array_ops.ones(value)).numpy() + + with self.test_scope(): + # Shapes of directly constructed tensors + self.assertAllEqual([], const(3)) + self.assertAllEqual([3], const([1.0, 2.0, 3.0])) + self.assertAllEqual([2, 2], const([[1.0, 2.0], [3.0, 4.0]])) + self.assertAllEqual([2, 1, 2], const([[[1.0, 2.0]], [[3.0, 4.0]]])) + + # Shapes of tensors created by op running on device + # We make this distinction because directly constructed tensors + # are treated differently in a few places that can influence shape: + # - they always have on_host_tensor + # - they and their shapes can be cached + # - they end up on device via a copy, instead of as program output + self.assertAllEqual([], ones([])) + self.assertAllEqual([3], ones([3])) + self.assertAllEqual([2, 2], ones([2, 2])) + self.assertAllEqual([2, 1, 2], ones([2, 1, 2])) + + def testShapeN(self): + with self.test_scope(): + # Shapes of directly constructed tensors + shapes = array_ops.shape_n([ + constant_op.constant(1.0), + constant_op.constant([1.0, 2.0, 3.0]), + constant_op.constant([[1.0, 2.0], [3.0, 4.0]])]) + self.assertAllEqual( + [[], [3], [2, 2]], + [x.numpy().tolist() for x in shapes]) + + # Shapes of tensors created by op running on device + shapes = array_ops.shape_n([ + array_ops.ones([]), + array_ops.ones([3]), + array_ops.ones([2, 2])]) + self.assertAllEqual( + [[], [3], [2, 2]], + [x.numpy().tolist() for x in shapes]) + + def testSize(self): + with self.test_scope(): + self.assertEqual( + 1, array_ops.size(constant_op.constant(1.0)).numpy()) + self.assertEqual( + 3, array_ops.size(constant_op.constant([1.0, 2.0, 3.0])).numpy()) + self.assertEqual( + 4, array_ops.size( + constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy()) + + def testRank(self): + with self.test_scope(): + self.assertEqual( + 0, 
array_ops.rank(constant_op.constant(1.0)).numpy()) + self.assertEqual( + 1, array_ops.rank(constant_op.constant([1.0, 2.0, 3.0])).numpy()) + self.assertEqual( + 2, array_ops.rank( + constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy()) + class EagerFunctionTest(XLATestCase): @@ -234,6 +329,74 @@ class EagerFunctionTest(XLATestCase): self.assertAllEqual([[1.]], c.numpy()) self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy()) + def testDefunInGradientTape(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + + @function.defun(compiled=True) + def f(x): + x = v0 * v0 * x + return x + + x = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + y = f(x) + dy = tape.gradient(y, v0) + + self.assertEqual(75, y.numpy()) + self.assertEqual(30, dy.numpy()) + + +class ExcessivePaddingTest(XLATestCase): + """Test that eager execution works with TPU flattened tensors. + + Tensors that would normally be excessively padded when written + to TPU memory are reshaped to 1-D flat tensors. + + This test case verifies that such tensors work with eager execution. + + The flattening currently only happens on TPU, but tests should work + fine with all backends as flattening is transparent. + """ + + def testFromConstant(self): + with self.test_scope(): + # Create constant of shape [100, 2, 1]. This tensor would be + # excessively padded on TPU. + tensor = constant_op.constant(100 * [[[10.0], [2.0]]]) + # Use reduce_sum since it requires correctly working with + # a particular dimension. 
+ reduced = math_ops.reduce_sum(tensor, axis=1) + self.assertAllEqual(100 * [[12.0]], reduced) + + def testFromOperation(self): + with self.test_scope(): + tensor = array_ops.ones([3, 100, 2, 2]) + reduced = math_ops.reduce_sum(tensor, axis=[0, 2, 3]) + self.assertAllEqual(100 * [12.0], reduced) + + def testAsFunctionInput(self): + with self.test_scope(): + + @function.defun(compiled=True) + def f(x): + return math_ops.reduce_sum(x, axis=2) + + tensor = constant_op.constant(100 * [[[10.0, 2.0]]]) + reduced = f(tensor) + self.assertAllEqual(100 * [[12.0]], reduced) + + def testAsFunctionOutput(self): + with self.test_scope(): + + @function.defun(compiled=True) + def f(x): + return x * constant_op.constant(100 * [[[10.0, 2.0]]]) + + y = f(3) + reduced = math_ops.reduce_sum(y, axis=2) + self.assertAllEqual(100 * [[36.0]], reduced) + if __name__ == '__main__': ops.enable_eager_execution( diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index fbc3c994d163a504351fcccd1ba71a0997e6516f..8a3f4b0bdc7a61d6cfa2ba7474ce8579e293a5c7 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -24,12 +24,10 @@ from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -@test_util.with_c_api class FunctionTest(XLATestCase): def testFunction(self): diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 42e637734c578fcc70473060cb156e172a0a1995..7cf953ef25ef5daf8a6d4fc9985ed8dbfb2081e5 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -65,9 +65,7 @@ class RGBToHSVTest(XLATestCase): join1 
= array_ops.stack(split1) join2 = array_ops.stack(split2) batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2], - { - batch0: inp - }) + {batch0: inp}) # Verify that processing batch elements together is the same as separate self.assertAllClose(batch1, join1) @@ -401,9 +399,7 @@ class AdjustSaturationTest(XLATestCase): x = array_ops.placeholder(dtypes.float32, shape=x_shape) with self.test_scope(): y_fused = self._adjust_saturation(x, - scale).eval(feed_dict={ - x: x_np - }) + scale).eval(feed_dict={x: x_np}) self.assertAllClose(y_fused, y_baseline, rtol=2e-5, atol=1e-5) @@ -412,7 +408,8 @@ class ResizeBilinearTest(XLATestCase): def _assertForwardOpMatchesExpected(self, image_np, target_shape, - expected=None): + expected=None, + large_tolerance=False): if expected is None: self.fail("expected must be specified") with self.test_session() as sess, self.test_scope(): @@ -420,7 +417,11 @@ class ResizeBilinearTest(XLATestCase): resized = gen_image_ops.resize_bilinear( image, target_shape, align_corners=True) out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]}) - self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) + if large_tolerance: + self.assertAllClose( + expected[np.newaxis, :, :, np.newaxis], out, rtol=0.03, atol=0.1) + else: + self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) def _assertBackwardOpMatchesExpected(self, grads_np, @@ -555,6 +556,28 @@ class ResizeBilinearTest(XLATestCase): [[12.5, 27.5, 21.875], [42.5, 80.0, 57.5], [40.625, 72.5, 50]], dtype=np.float32)) + def testAlignCorners4x4To8x8(self): + self._assertForwardOpMatchesExpected( + (np.array([[0, 1, 2, 3]], dtype=np.float32) + np.array( + [[0], [1], [2], [3]], dtype=np.float32)) * 7.0, [8, 8], + expected=3 * + (np.array([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=np.float32) + np.array( + [[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.float32)), + large_tolerance=True) + + def testAlignCorners8x8To16x16(self): + 
self._assertForwardOpMatchesExpected( + (np.array([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=np.float32) + np.array( + [[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.float32)) * 15.0, + [16, 16], + expected=7 * (np.array( + [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]], + dtype=np.float32) + np.array( + [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], + [12], [13], [14], [15]], + dtype=np.float32)), + large_tolerance=True) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 1ad83d80409734efd1f5a0a9fc39f5b7d064d54b..6e0db54b7a74b284dc7d18bcbb07c178c664c1e5 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -29,13 +29,11 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.layers import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import variables from tensorflow.python.platform import test jit_scope = jit.experimental_jit_scope @@ -80,10 +78,10 @@ def InLabels(labels, substr): def MetadataHasXlaLaunch(run_metadata): - """Returns true if there is a _XlaLaunch kernel in run_metadata's timeline.""" + """Returns true if there is a XlaLaunch kernel in run_metadata's timeline.""" # TODO(phawkins): find a less hacky way to test whether a kernel ran. - return InLabels(RunMetadataLabels(run_metadata), "_XlaLaunch") + return InLabels(RunMetadataLabels(run_metadata), "XlaLaunch") class JitLaunchTest(test.TestCase): @@ -92,8 +90,8 @@ class JitLaunchTest(test.TestCase): # Verifies that the outputs match and that XLA was invoked. 
'fn' must take # the same number of tensors as arguments that are in 'args', and must return # a tuple of output tensors. - # If 'require_kernel_launch' is True, then we verify that a _XlaLaunch node - # actually ran. However, it is sometimes possible for _XlaLaunch ops to be + # If 'require_kernel_launch' is True, then we verify that a XlaLaunch node + # actually ran. However, it is sometimes possible for XlaLaunch ops to be # constant-folded away, so the check is optional. def _compare(self, fn, args, require_kernel_launch=True, noinline=None): with session_lib.Session(config=NoRewriteSessionConfig()) as sess: @@ -127,7 +125,7 @@ class JitLaunchTest(test.TestCase): for (x, y) in zip(compiled, direct): self.assertAllClose(x, y, rtol=1e-1) else: - self.assertAllClose(compiled, direct) + self.assertAllClose(compiled, direct, rtol=1e-2) def testNoOutputs(self): with session_lib.Session() as sess: @@ -443,31 +441,14 @@ class XlaCompilationTest(test.TestCase): self.assertFalse(InLabels(labels, "Log")) self.assertTrue(InLabels(labels, "Reciprocal")) self.assertTrue(InLabels(labels, "Mul")) - self.assertFalse(InLabels(labels, "_XlaLaunch")) + self.assertFalse(InLabels(labels, "XlaLaunch")) - # Compile the backprop. One _XlaLaunch. + # Compile the backprop. One XlaLaunch. 
labels = _Run(compiled=True) self.assertFalse(InLabels(labels, "Log")) self.assertFalse(InLabels(labels, "Reciprocal")) self.assertFalse(InLabels(labels, "Mul")) - self.assertTrue(InLabels(labels, "_XlaLaunch")) - - def testDenseLayer(self): - """Tests that the dense layer node is properly compiled.""" - - with self.test_session(config=NoRewriteSessionConfig()) as sess: - x = array_ops.placeholder(shape=[2, 3], dtype=np.float32) - with jit_scope(): - y = layers.dense(x, 3) - - sess.run(variables.initialize_all_variables()) - run_metadata = config_pb2.RunMetadata() - sess.run(y, {x: np.array([[1, 2, 3], [4, 5, 6]])}, - run_metadata=run_metadata, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) - - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assertTrue(InLabels(labels, "XlaLaunch")) class ElementWiseFusionTest(test.TestCase): @@ -501,7 +482,7 @@ class ElementWiseFusionTest(test.TestCase): trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = RunMetadataLabels(run_metadata) - count = sum("_XlaLaunch(" in x for x in labels) + count = sum("XlaLaunch(" in x for x in labels) return output, count diff --git a/tensorflow/compiler/tests/listdiff_op_test.py b/tensorflow/compiler/tests/listdiff_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..45a04f0cf56e88946b946bedacb25ce6da3121b4 --- /dev/null +++ b/tensorflow/compiler/tests/listdiff_op_test.py @@ -0,0 +1,101 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for XLA listdiff operator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ListDiffTest(xla_test.XLATestCase): + + def _testListDiff(self, x, y, out, idx): + for dtype in [dtypes.int32, dtypes.int64]: + for index_dtype in [dtypes.int32, dtypes.int64]: + with self.test_session() as sess: + x_tensor = ops.convert_to_tensor(x, dtype=dtype) + y_tensor = ops.convert_to_tensor(y, dtype=dtype) + with self.test_scope(): + out_tensor, idx_tensor = array_ops.listdiff( + x_tensor, y_tensor, out_idx=index_dtype) + tf_out, tf_idx = sess.run([out_tensor, idx_tensor]) + self.assertAllEqual(out, tf_out) + self.assertAllEqual(idx, tf_idx) + self.assertEqual(1, out_tensor.get_shape().ndims) + self.assertEqual(1, idx_tensor.get_shape().ndims) + + def testBasic1(self): + self._testListDiff(x=[1, 2, 3, 4], y=[1, 2], out=[3, 4], idx=[2, 3]) + + def testBasic2(self): + self._testListDiff(x=[1, 2, 3, 4], y=[2], out=[1, 3, 4], idx=[0, 2, 3]) + + def testBasic3(self): + self._testListDiff(x=[1, 4, 3, 2], y=[4, 2], out=[1, 3], idx=[0, 2]) + + def testDuplicates(self): + self._testListDiff(x=[1, 2, 4, 3, 2, 3, 3, 1], + y=[4, 2], + out=[1, 3, 3, 3, 1], + idx=[0, 3, 5, 6, 7]) + + def testRandom(self): + num_random_tests = 10 + int_low = -7 + int_high = 8 + max_size = 50 + for _ in xrange(num_random_tests): + x_size = np.random.randint(max_size + 1) + x = np.random.randint(int_low, int_high, size=x_size) + 
y_size = np.random.randint(max_size + 1) + y = np.random.randint(int_low, int_high, size=y_size) + out_idx = [(entry, pos) for pos, entry in enumerate(x) if entry not in y] + if out_idx: + out, idx = map(list, zip(*out_idx)) + else: + out = [] + idx = [] + self._testListDiff(list(x), list(y), out, idx) + + def testFullyOverlapping(self): + self._testListDiff(x=[1, 2, 3, 4], y=[1, 2, 3, 4], out=[], idx=[]) + + def testNonOverlapping(self): + self._testListDiff(x=[1, 2, 3, 4], + y=[5, 6], + out=[1, 2, 3, 4], + idx=[0, 1, 2, 3]) + + def testEmptyX(self): + self._testListDiff(x=[], y=[1, 2], out=[], idx=[]) + + def testEmptyY(self): + self._testListDiff(x=[1, 2, 3, 4], y=[], out=[1, 2, 3, 4], idx=[0, 1, 2, 3]) + + def testEmptyXY(self): + self._testListDiff(x=[], y=[], out=[], idx=[]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index d6c93088d4efff7d8306e262a79ae49d3d8ac722..f13dff96203b5480480c2a2fc9ac38ca78b7f78a 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -22,6 +22,8 @@ import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import googletest @@ -47,18 +49,18 @@ class RandomOpsTest(XLATestCase): # We use exact equality here. If the random-number generator is producing # deterministic output, all three outputs will be bitwise identical. 
self.assertTrue((not np.array_equal(y, z)) or - (not np.array_equal(z, w)) or - (not np.array_equal(y, w))) + (not np.array_equal(z, w)) or (not np.array_equal(y, w))) def testRandomUniformIsNotConstant(self): + def rng(dtype): - return random_ops.random_uniform(shape=[2], dtype=dtype, - maxval=1000000) + return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000) for dtype in self._random_types(): self._testRngIsNotConstant(rng, dtype) def testRandomNormalIsNotConstant(self): + def rng(dtype): return random_ops.random_normal(shape=[2], dtype=dtype) @@ -70,12 +72,20 @@ class RandomOpsTest(XLATestCase): for dtype in self._random_types(): with self.test_session() as sess: with self.test_scope(): - x = random_ops.random_uniform(shape=[1000], dtype=dtype, minval=-2, - maxval=33) + x = random_ops.random_uniform( + shape=[1000], dtype=dtype, minval=-2, maxval=33) y = sess.run(x) self.assertTrue((y >= -2).sum() == 1000) self.assertTrue((y < 33).sum() == 1000) + def testTruncatedNormalIsNotConstant(self): + + def rng(dtype): + return random_ops.truncated_normal(shape=[2], dtype=dtype) + + # TODO(b/34339814): implement inverse erf support for non-F32 types. + self._testRngIsNotConstant(rng, dtypes.float32) + def testTruncatedNormalIsInRange(self): count = 10000 # TODO(b/34339814): implement inverse erf support for non-F32 types. @@ -87,6 +97,29 @@ class RandomOpsTest(XLATestCase): self.assertTrue((y >= -2).sum() == count) self.assertTrue((y <= 2).sum() == count) + def testShuffle1d(self): + with self.test_session() as sess: + with self.test_scope(): + x = math_ops.range(20) + shuffle = random_ops.random_shuffle(x) + result = sess.run(shuffle) + expected = range(20) + # Compare sets to avoid randomness behavior changes but make sure still + # have all the values. 
+ self.assertAllEqual(set(result), set(expected)) + + def testShuffle2d(self): + with self.test_session() as sess: + with self.test_scope(): + x = array_ops.diag(math_ops.range(20)) + shuffle = random_ops.random_shuffle(x) + result = sess.run(shuffle) + expected = np.diag(range(20)).flatten() + # Compare sets to avoid randomness behavior changes but make sure still + # have all the values. + self.assertAllEqual(len(result.flatten()), len(expected)) + self.assertAllEqual(set(result.flatten()), set(expected)) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index e53efc3091d8935e745122af29abd7b8063b1d01..16f293891d56d78885dd515bb7b9899faf0690f7 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -619,8 +619,8 @@ std::vector OpTest::ImageDims(TensorFormat format, int batch, dims.push_back(dim); } break; - case FORMAT_NCHW_VECT_C: - LOG(FATAL) << "FORMAT_NCHW_VECT_C not supported."; + default: + LOG(FATAL) << "Tensor format " << ToString(format) << " not supported."; } return dims; } diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index ba79f393a8f9b24ac506d2130957c38ecd442509..689a4a1f4e02f5dd48f64dc94afd0fcb50df8b5b 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -209,7 +209,8 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.expm1, np.array([[-1, 1]], dtype=dtype), - expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype)) + expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype), + rtol=1e-5) self._assertOpOutputMatchesExpected( math_ops.floor, @@ -251,12 +252,12 @@ class UnaryOpsTest(XLATestCase): np.array([[1, 2]], dtype=dtype), expected=np.array([[0.540297, -0.41614]], dtype=dtype)) - # TODO(b/34703906): improve log1p implementation and make 
tolerance - # tighter. self._assertOpOutputMatchesExpected( math_ops.log1p, np.array([[1e-14, 1e-15, 0.6]], dtype=dtype), - expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype))) + expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)), + rtol=1e-4, + atol=1e-6) self._assertOpOutputMatchesExpected( math_ops.rint, @@ -333,13 +334,19 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( nn_ops.elu, - np.array([[-1, 0, 1]], dtype=dtype), - expected=np.array([[-0.63212056, 0, 1]], dtype=dtype)) + np.array([[-1, 0, 1, -1e-6]], dtype=dtype), + expected=np.array([[-0.63212056, 0, 1, -9.999995e-07]], dtype=dtype), + rtol=1e-5, + atol=1e-6) self._assertOpOutputMatchesExpected( nn_ops.selu, - np.array([[-1, 0, 1]], dtype=dtype), - expected=np.array([[-1.11133074, 0., 1.05070099]], dtype=dtype)) + np.array([[-1, 0, 1, -1e-5]], dtype=dtype), + expected=np.array( + [[-1.11133074, 0., 1.05070099, -1.758090550379974e-05]], + dtype=dtype), + rtol=1e-5, + atol=1e-6) self._assertOpOutputMatchesExpected( nn_ops.relu, @@ -419,7 +426,9 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.expm1, np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype), - expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype))) + expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype)), + rtol=1e-6, + atol=1e-6) self._assertOpOutputMatchesExpected( math_ops.reciprocal, @@ -441,13 +450,13 @@ class UnaryOpsTest(XLATestCase): np.array([[5j, 3 - 2j]], dtype=dtype), expected=np.cos(np.array([[5j, 3 - 2j]], dtype=dtype))) - # TODO(b/34703906): improve log1p implementation and make tolerance - # tighter. 
self._assertOpOutputMatchesExpected( math_ops.log1p, np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype), expected=np.log1p( - np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype))) + np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype)), + rtol=1e-4, + atol=1e-6) val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype) self._assertOpOutputMatchesExpected( @@ -789,7 +798,9 @@ class UnaryOpsTest(XLATestCase): zero = np.asarray(0).astype(dtype) expected = np.logaddexp(zero, features) self._assertOpOutputMatchesExpected( - nn_ops.softplus, features, expected=expected) + nn_ops.softplus, features, expected=expected, + rtol=1e-6, + atol=9.1e-6) def testSoftplus(self): for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index 8ecad00f6e23b3a7746bbb473102ac847bf4cbfd..2c09b03d5a35cde2c42d8a145781270c0c908587 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -187,6 +187,25 @@ class VariableOpsTest(XLATestCase): rtol=1e-4) self.assertAllClose(np.array([1.9, 2.9], dtype=np.float32), vb, rtol=1e-4) + def testWriteOfAliasedTensor(self): + for dtype in self.numeric_types: + init = np.array([[1, 2j], [3, 4]]).astype(dtype) + update = np.array([[7, 1j], [2, 11]]).astype(dtype) + with self.test_session() as sess, self.test_scope(): + v = resource_variable_ops.ResourceVariable(init) + sess.run(variables.variables_initializer([v])) + p = array_ops.placeholder(dtype) + q = array_ops.identity(p) + x = v.read_value() + # Writes the value of 'p' to 'v', but keeps a reference to the original + # value of 'v' so the variable update cannot reuse its buffer. 
+ with ops.control_dependencies([x]): + y = v.assign(q) + result = sess.run([x, y, q], {p: update}) + self.assertAllClose(init, result[0]) + self.assertAllClose(update, result[1]) + self.assertAllClose(update, result[2]) + class StridedSliceAssignChecker(object): """Compares the results of a slice assignment using Tensorflow and numpy.""" diff --git a/tensorflow/compiler/tests/xla_device_gpu_test.py b/tensorflow/compiler/tests/xla_device_gpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1e30ebd55d09fe00449fb67b92a8325f5809d89a --- /dev/null +++ b/tensorflow/compiler/tests/xla_device_gpu_test.py @@ -0,0 +1,48 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Test cases for XLA devices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.client import session as session_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class XlaDeviceGpuTest(test.TestCase): + + def testCopiesToAndFromGpuWork(self): + """Tests that copies between GPU and XLA devices work.""" + if not test.is_gpu_available(): + return + + with session_lib.Session() as sess: + x = array_ops.placeholder(dtypes.float32, [2]) + with ops.device("GPU"): + y = x * 2 + with ops.device("device:XLA_CPU:0"): + z = y * y + with ops.device("GPU"): + w = y + z + result = sess.run(w, {x: [1.5, 0.5]}) + self.assertAllClose(result, [12., 2.], rtol=1e-3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index f5c228f8305d740b994dadc34c93b4e0ae32d785..f0b010fa67f2ffb3f81fd14d4d89585f716b4890 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,30 +18,40 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.client import session as session_lib -from tensorflow.python.framework import dtypes +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.platform import test -class XlaDeviceTest(test.TestCase): +class XlaDeviceTest(XLATestCase): def testCopies(self): - """Tests that copies between GPU and XLA devices work.""" - if not test.is_gpu_available(): - return - - with session_lib.Session() as sess: - x = array_ops.placeholder(dtypes.float32, [2]) - with ops.device("GPU"): - y = x * 2 - with ops.device("device:XLA_CPU:0"): - z = y * y - with ops.device("GPU"): - w = y + z - result = sess.run(w, {x: [1.5, 0.5]}) - self.assertAllClose(result, [12., 2.], rtol=1e-3) + """Tests that copies onto and off XLA devices work.""" + shapes = [[0], [1], [1, 0], [1024, 0], [1024, 1], [3, 777], [777, 3], + [16384, 1], [1, 16384], [1, 20000, 1, 1]] + for dtype in self.numeric_types: + for shape in shapes: + with self.test_session() as sess: + with ops.device("CPU"): + x = array_ops.placeholder(dtype, shape) + with self.test_scope(): + y = x + x + with ops.device("CPU"): + z = array_ops.identity(y) + + inputs = np.random.randint(-100, 100, shape).astype(dtype) + result = sess.run(z, {x: inputs}) + self.assertAllCloseAccordingToType(result, inputs + inputs) + + def testControlTrigger(self): + with self.test_session() as sess: + with self.test_scope(): + x = gen_control_flow_ops.control_trigger() + sess.run(x) if __name__ == "__main__": diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 4fca51f54d320e843343f80d7df1177f80f1d99f..cd57452302fcbde37d79ce760a80615a76d7ad8c 100644 --- a/tensorflow/compiler/tf2xla/BUILD 
+++ b/tensorflow/compiler/tf2xla/BUILD @@ -325,6 +325,7 @@ tf_cc_test( "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:cpu_plugin", diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD index 4f8bb8ad743afe69a6544c2ae0dc7309891b2df3..ea8d1b3d14939d4f4fba598318200f71c2eb0270 100644 --- a/tensorflow/compiler/tf2xla/cc/BUILD +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -27,3 +27,25 @@ cc_library( "//tensorflow/core:protos_all_cc", ], ) + +tf_gen_op_wrapper_cc( + name = "xla_jit_op_gen", + out_ops_file = "ops/xla_jit_op", + deps = ["//tensorflow/compiler/jit/ops:xla_ops"], +) + +cc_library( + name = "xla_jit_ops", + srcs = ["ops/xla_jit_op.cc"], + hdrs = ["ops/xla_jit_op.h"], + deps = [ + "//tensorflow/cc:const_op", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/compiler/jit/ops:xla_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 8d1f2684909e876fe5521ba6a63d745c7d3956e0..1438f6b48c4913e60b0c0a9f5c3d67fe595cbfe8 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -282,7 +282,58 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, return Status::OK(); } -Status FunctionalizeLoop(Graph* graph, Frame* frame, +// Copy the FunctionDef of given function from lookup_library to library, if +// it can be found in lookup_library but is missing from library. 
+Status AddMissingFunctionByName(const string& function_name, + const FunctionLibraryDefinition* lookup_library, + FunctionLibraryDefinition* library) { + if (!library->Find(function_name) && lookup_library->Find(function_name)) { + return library->AddFunctionDef(*lookup_library->Find(function_name)); + } + return Status::OK(); +} + +// Iterate over all functions that the given fdef refers to. Copy the missing +// FunctionDefs from lookup_library to library. +Status AddMissingFunctionDef(const FunctionDef& fdef, + const FunctionLibraryDefinition* lookup_library, + FunctionLibraryDefinition* library) { + TF_RET_CHECK(lookup_library); + for (const NodeDef& node : fdef.node_def()) { + if (library->Find(node.op())) { + continue; + } + // The function referred to by a 'SymbolicGradient' node is specified in its + // attribute 'f'. + if (node.op() == FunctionLibraryDefinition::kGradientOp) { + const AttrValue* attr = + AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr); + if (!attr) { + return errors::InvalidArgument("SymbolicGradient is missing attr: f"); + } + const string& func_name = attr->func().name(); + TF_RETURN_IF_ERROR( + AddMissingFunctionByName(func_name, lookup_library, library)); + // Copy the user-defined gradient function if it exists. 
+ const string grad_name = lookup_library->FindGradient(func_name); + if (!grad_name.empty() && library->FindGradient(func_name).empty()) { + TF_RETURN_IF_ERROR( + AddMissingFunctionByName(grad_name, lookup_library, library)); + GradientDef grad_def; + grad_def.set_function_name(func_name); + grad_def.set_gradient_func(grad_name); + TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def)); + } + } else if (lookup_library->Find(node.op())) { + TF_RETURN_IF_ERROR( + library->AddFunctionDef(*lookup_library->Find(node.op()))); + } + } + return Status::OK(); +} + +Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, + Graph* graph, Frame* frame, FunctionLibraryDefinition* library) { VLOG(2) << "Frame " << frame->name << " before: " << dump_graph::DumpGraphToFile("functionalize_before", *graph, @@ -489,6 +540,14 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef)); TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef)); + if (lookup_library) { + // Copy missing FunctionDefs from lookup_library to library to make library + // self-contained. + TF_RETURN_IF_ERROR( + AddMissingFunctionDef(cond_fdef, lookup_library, library)); + TF_RETURN_IF_ERROR( + AddMissingFunctionDef(body_fdef, lookup_library, library)); + } // Builds a While operator. NodeDef while_def; @@ -1365,6 +1424,12 @@ Status FunctionalizeCond::Functionalize(Graph* graph, // functional equivalents. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library) { + return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); +} + +Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, + Graph* graph, + FunctionLibraryDefinition* library) { VLOG(2) << "FunctionalizeControlFlow (initial): " << dump_graph::DumpGraphToFile("functionalize_initial", *graph, library); @@ -1373,7 +1438,13 @@ Status FunctionalizeControlFlow(Graph* graph, // connected to all source nodes in the graph. 
Many graphs violate this // invariant. std::vector cf_info; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info)); + std::vector unreachable_nodes; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes)); + if (!unreachable_nodes.empty()) { + return errors::InvalidArgument( + "The following nodes are unreachable from the source in the graph: ", + tensorflow::str_util::Join(unreachable_nodes, ", ")); + } // Builds Frames, indexed by name. std::unordered_map frames; @@ -1434,7 +1505,8 @@ Status FunctionalizeControlFlow(Graph* graph, continue; } - TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library)); + TF_RETURN_IF_ERROR( + FunctionalizeLoop(lookup_library, graph, frame, library)); // If the parent has no remaining children, add it to the worklist. --frame->parent->num_children; diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index 4d4ee3054c2914bb614bf75f7a51be8f6292683e..d941041d15532446d1413f16fe64602bfb1a7daa 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -22,9 +22,13 @@ limitations under the License. namespace tensorflow { // Transformation that converts tf.while_loop() loops into functional While -// operators, suitable for XLA compilation. +// operators, suitable for XLA compilation. If lookup_library is provided, use +// it to make the library for control flow self-contained. 
Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library); +Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, + Graph* graph, + FunctionLibraryDefinition* library); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index e494f42e8ed254ac0c7c7a23a13728d3f015e9d3..14977a908ae2b0ff7e13b634c41b6d331b4b8a36 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -299,6 +299,131 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { } } +// @function.Defun(noinline=True) +// def increment_fn(x): +// return [x + 1] +// Define the above function, and add it to the given graph. It's used as the +// while loop body in NoinlineLoopBody test. +Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { + FunctionDef fdef = FunctionDefHelper::Create( + "increment_fn", {"x:int32"}, {"add:int32"}, {}, + { + {{"add/y"}, "Const", {}, {{"dtype", DT_INT32}}}, + {{"add_0"}, "Add", {"x", "add/y:output:0"}, {{"T", DT_INT32}}}, + }, + {{"add", "add_0:z:0"}}); + (*fdef.mutable_attr())["_noinline"].set_b(true); + FunctionDefLibrary fdef_lib; + *(fdef_lib.add_function()) = fdef; + TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib)); + NodeDef increment_fn; + increment_fn.set_name(node_name); + increment_fn.set_op("increment_fn"); + *increment_fn.add_input() = "while/Identity"; + *increment_fn.add_input() = "^while/Identity"; + Status status; + graph->AddNode(increment_fn, &status); + return status; +} + +// Graph: +// x = array_ops.placeholder(dtypes.int32) +// y = control_flow_ops.while_loop(lambda i: i < 10, increment_fn, [x]) +TEST(FunctionalizeControlFlow, NoinlineLoopBody) { + const string& noinline_node_name = "while/increment_fn"; + Graph graph(OpRegistry::Global()); + { + Scope scope = 
Scope::NewRootScope().ExitOnError(); + auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto enter = ops::internal::Enter(scope.WithOpName("while/Enter"), source, + "while/while_context"); + auto merge = ops::Merge(scope.WithOpName("while/Merge"), + std::initializer_list{enter, dummy}); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), + 10); + auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); + auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); + auto switch_ = + ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond); + auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"), + switch_.output_false); + auto identity = + ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true); + + TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); + + NodeDef next_iter; + next_iter.set_name("while/NextIteration"); + next_iter.set_op("NextIteration"); + *next_iter.add_input() = noinline_node_name; + (*next_iter.mutable_attr())["T"].set_type(DT_INT32); + + Status status; + Node* n = scope.graph()->AddNode(next_iter, &status); + TF_ASSERT_OK(status); + + // Remove the dummy node and add the loop backedge. + scope.graph()->RemoveNode(dummy.node()); + scope.graph()->AddEdge(n, 0, merge.output.node(), 1); + TF_ASSERT_OK(scope.ToGraph(&graph)); + } + + FunctionLibraryDefinition lookup_lib(graph.flib_def()); + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + // Function increment_fn will be copied from lookup_lib to library. 
+ TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + + NameAttrList cond_fn, body_fn; + TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::XlaWhile(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // Body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + NodeDef retval; + retval.set_name("_retval0_RetVal"); + retval.set_op(FunctionLibraryDefinition::kRetOp); + *retval.add_input() = noinline_node_name; + (*retval.mutable_attr())["T"].set_type(DT_INT32); + (*retval.mutable_attr())["index"].set_i(0); + Status status; + scope.graph()->AddNode(retval, &status); + TF_ASSERT_OK(status); + + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + // Verify that increment_fn has been copied to library. + TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + // Ignore the function library when comparing the graphs. + expected.clear_library(); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + // Tests functionalizing OneLoopVar where the loop value is not used post the // loop. 
// Graph: diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 8115a26210a8e9e95e851f350e34dcdfa2519a64..212f6f3966149ca0b2d2e012b19300e1f488f996 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -208,10 +208,11 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, TF_RETURN_IF_ERROR( PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments)); + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = false; XlaCompiler::CompilationResult result; - - TF_RETURN_IF_ERROR(compiler->CompileFunction(XlaCompiler::CompileOptions(), - func, arguments, &result)); + TF_RETURN_IF_ERROR( + compiler->CompileFunction(compile_options, func, arguments, &result)); TF_RET_CHECK(arguments.size() == expressions.size()); @@ -229,11 +230,14 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, auto output_handle = b->Call(*result.computation, handles); // The output handle of `Call` computation is a tuple type. Unzip it so // that it can fit into future computations. 
+ int computation_output = 0; for (int64 i = 0; i < n->num_outputs(); ++i) { if (result.outputs[i].is_constant) { xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value); } else { - xla_op_context.SetOutput(i, b->GetTupleElement(output_handle, i)); + xla_op_context.SetOutput( + i, b->GetTupleElement(output_handle, computation_output)); + ++computation_output; } } return b->first_error(); diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 85ab4c41bf6a754236066260819f103970e603ae..edd2ab6301ee891c433639ce300cde0c72929cea 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -18,6 +18,7 @@ tf_kernel_library( "bcast_ops.cc", "bias_ops.cc", "binary_ops.cc", + "bucketize_op.cc", "cast_op.cc", "categorical_op.cc", "cholesky_op.cc", @@ -45,6 +46,7 @@ tf_kernel_library( "image_resize_ops.cc", "index_ops.cc", "l2loss_op.cc", + "listdiff_op.cc", "lrn_ops.cc", "matmul_op.cc", "matrix_band_part_op.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca9a6b40688d1e8496d1b823e20d273d519f65e8 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BucketizeOp : public XlaOpKernel { + public: + explicit BucketizeOp(OpKernelConstruction* context) : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("boundaries", &boundaries_)); + OP_REQUIRES(context, std::is_sorted(boundaries_.begin(), boundaries_.end()), + errors::InvalidArgument("Expected sorted boundaries")); + } + + void Compile(XlaOpKernelContext* context) override { + xla::XlaBuilder* builder = context->builder(); + const DataType dtype = context->input_type(0); + xla::XlaOp input = context->Input(0); + + xla::XlaOp boundaries = builder->ConstantR1(boundaries_); + // TODO(phawkins): the following behavior matches the behavior of the core + // Bucketize kernel. However, comparing an int32 or int64 against float may + // lead to inaccurate bucketing due to rounding. 
+ if (dtype == DT_DOUBLE) { + input = builder->ConvertElementType(input, xla::F64); + boundaries = builder->ConvertElementType(boundaries, xla::F64); + } else { + input = builder->ConvertElementType(input, xla::F32); + } + xla::XlaOp comparison = builder->ConvertElementType( + builder->Ge(builder->Broadcast(input, {1}), boundaries, + /*broadcast_dimensions=*/{0}), + xla::S32); + xla::XlaOp buckets = builder->Reduce( + comparison, /*init_value=*/builder->ConstantR0(0), + /*computation=*/xla::CreateScalarAddComputation(xla::S32, builder), + /*dimensions_to_reduce=*/{0}); + context->SetOutput(0, buckets); + } + + private: + std::vector boundaries_; +}; + +REGISTER_XLA_OP(Name("Bucketize"), BucketizeOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc index ed7462c16615f7f63a174e29843c2a1675c17058..493781a1e68b8906f1a7e018e5710130e2eb08b5 100644 --- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -34,9 +34,8 @@ class EluOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); - const auto one = XlaHelpers::One(b, input_type(0)); const auto pred = b->Gt(ctx->Input(0), zero); - const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one); + const auto expm1 = b->Expm1(ctx->Input(0)); ctx->SetOutput(0, b->Select(pred, ctx->Input(0), expm1)); } }; @@ -68,13 +67,12 @@ class SeluOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); - const auto one = XlaHelpers::One(b, input_type(0)); const auto scale = XlaHelpers::FloatLiteral(b, input_type(0), 1.0507009873554804934193349852946); const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0), 1.7580993408473768599402175208123); const auto pred = 
b->Gt(ctx->Input(0), zero); - const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one); + const auto expm1 = b->Expm1(ctx->Input(0)); ctx->SetOutput(0, b->Select(pred, b->Mul(scale, ctx->Input(0)), b->Mul(scale_alpha, expm1))); } diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 8b9b026643cf35216a2082dfcce9270c017bd14f..d48c6eea754f75a8879d3938f233a6a591d26d0d 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -48,11 +48,11 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Building If: " << input_types_.size() << " inputs"; - std::vector inputs(input_types_.size()); std::vector arguments(input_types_.size()); for (int i = 0; i < input_types_.size(); ++i) { XlaCompiler::Argument& arg = arguments[i]; DataType type = ctx->input_type(i + 1); + if (type == DT_RESOURCE) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource)); @@ -60,7 +60,6 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { arg.initialized = resource->initialized(); arg.kind = XlaCompiler::Argument::kResource; arg.resource_kind = resource->kind(); - OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); arg.type = resource->type(); arg.shape = resource->shape(); @@ -79,7 +78,6 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { arg.kind = XlaCompiler::Argument::kParameter; arg.type = input_types_[i]; arg.shape = ctx->InputShape(i + 1); - inputs[i] = ctx->Input(i + 1); VLOG(2) << "Arg type: " << DataTypeString(arg.type) << " shape: " << arg.shape.DebugString(); } @@ -100,6 +98,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_, arguments, &else_result)); + bool has_tensor_array_gradients = false; for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) { for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) { XlaResource* resource; @@ -121,9 
+120,21 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { for (const auto& gradient : resource->tensor_array_gradients()) { arg.tensor_array_gradients.insert(gradient.first); } + if (!resource->tensor_array_gradients().empty()) + has_tensor_array_gradients = true; } } + // Recompile the functions to update the argument shapes for tensor arrays. + if (has_tensor_array_gradients) { + then_result = {}; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, then_branch_, + arguments, &then_result)); + else_result = {}; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_, + arguments, &else_result)); + } + // Check that both branches have identical input shapes. OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1, errors::FailedPrecondition("Expected one input shape")); @@ -175,6 +186,19 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { "Mismatch in resource of then and else branch for resource ", i)); } + int num_inputs = then_result.input_mapping.size(); + std::vector inputs(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + int input_num = then_result.input_mapping[i] + 1; + if (ctx->input_type(input_num) == DT_RESOURCE) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); + OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); + } else { + inputs[i] = ctx->Input(i + 1); + } + } + xla::XlaOp outputs = b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation, b->Tuple(inputs), *else_result.computation); diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 9058cbc74762576c7e6f8ec1b2b0f6b247ac0502..79d3a6979cec4c6bda92a71dcff4ddd2151367d5 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -99,27 +99,34 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters( return dims; } +// Form a 2D convolution kernel like: +// 1 2 3 
2 1 +// 2 4 6 4 2 +// 1/9 * 3 6 9 6 3 +// 2 4 6 4 2 +// 1 2 3 2 1 +// by multiplying two 1D kernels of the form: +// 1/3 * [1 2 3 2 1] +// If the 2D kernel would be very large, the 1D kernel can be applied once in +// each dimension due to the symmetry of the kernel along all axes to reduce the +// computational intensity. +std::vector Make1DKernel(int64 n) { + std::vector kernel(n * 2 - 1); + for (int64 i = 0; i < n; ++i) { + float v = (i + 1.0f) / n; + kernel[i] = v; + kernel[n * 2 - 2 - i] = v; + } + return kernel; +} + +// Kernels with more than 16 spatial elements are considered intense and the +// kernel should be applied to each dimension independently. +const int64 kMax2DKernelSize = 16; + xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, gtl::ArraySlice kernel_size, int64 channels) { - // Form a 2D convolution kernel like: - // 1 2 3 2 1 - // 2 4 6 4 2 - // 1/9 * 3 6 9 6 3 - // 2 4 6 4 2 - // 1 2 3 2 1 - // by multiplying two 1D kernels of the form: - // 1/3 * [1 2 3 2 1] - auto make_1d_kernel = [](int64 n) { - std::vector kernel(n * 2 - 1); - for (int64 i = 0; i < n; ++i) { - float v = (i + 1.0f) / n; - kernel[i] = v; - kernel[n * 2 - 2 - i] = v; - } - return kernel; - }; - xla::XlaOp channels_iota; // DT_INT32 Iota will always return status::OK(). TF_CHECK_OK( @@ -133,12 +140,37 @@ xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, xla::PrimitiveType::F32); return builder->Mul( builder->Mul(diag, - builder->ConstantR1(make_1d_kernel(kernel_size[1])), + builder->ConstantR1(Make1DKernel(kernel_size[1])), /*broadcast_dimensions=*/{1}), - builder->ConstantR1(make_1d_kernel(kernel_size[0])), + builder->ConstantR1(Make1DKernel(kernel_size[0])), /*broadcast_dimensions=*/{0}); } +xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder, + gtl::ArraySlice kernel_size, + int64 channels, int64 dim) { + xla::XlaOp channels_iota; + // DT_INT32 Iota will always return status::OK(). 
+ TF_CHECK_OK( + XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); + + auto diag = builder->ConvertElementType( + builder->Eq(builder->Broadcast( + channels_iota, + {dim == 0 ? (2 * kernel_size[0] - 1) : 1, + dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}), + channels_iota, /*broadcast_dimensions=*/{2}), + xla::PrimitiveType::F32); + if (dim == 1) { + return builder->Mul( + diag, builder->ConstantR1(Make1DKernel(kernel_size[1])), + /*broadcast_dimensions=*/{1}); + } + return builder->Mul(diag, + builder->ConstantR1(Make1DKernel(kernel_size[0])), + /*broadcast_dimensions=*/{0}); +} + xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, const xla::XlaOp& input, const int num_spatial_dims, @@ -165,20 +197,42 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, dimension_numbers.add_output_spatial_dimensions(1 + i); dimension_numbers.add_kernel_spatial_dimensions(i); } - dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); - dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); ResizeConvolutionDims dims = ComputeResizeConvolutionParameters(in_size, out_size); - xla::XlaOp kernel = - MakeBilinearResizeKernel(builder, dims.kernel_size, channels); - xla::XlaOp output = builder->ConvGeneralDilated( - input, kernel, dims.stride, - /*padding=*/ - {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, - {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, - /*lhs_dilation=*/dims.kernel_size, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + xla::XlaOp output; + // Split convolutions into independent dimensions if they would be a very + // large kernel. 
+ if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { + xla::XlaOp kernel = + MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + output = builder->ConvGeneralDilated( + input, kernel, dims.stride, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.kernel_size, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } else { + xla::XlaOp kernel0 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); + output = builder->ConvGeneralDilated( + input, kernel0, {dims.stride[0], 1}, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, + /*lhs_dilation=*/{dims.kernel_size[0], 1}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + xla::XlaOp kernel1 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); + output = builder->ConvGeneralDilated( + output, kernel1, {1, dims.stride[1]}, + /*padding=*/ + {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/{1, dims.kernel_size[1]}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } // Add broadcasts to handle expanding from a size == 1 dimension to a // size > 1 dimension. @@ -214,26 +268,63 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, } dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); - xla::XlaOp kernel = - MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + xla::XlaOp output; + if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { + xla::XlaOp kernel = + MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + + // Broadcast the input kernel where the forward op expanded from a size == 1 + // dimension to a size > 1 dimension. This has the effect of summing the + // gradient contributions in that dimension. 
+ for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1 && grad_size[i] > 1) { + kernel = + builder->Add(kernel, builder->ConstantR1(grad_size[i], 0), + /*broadcast_dimensions=*/{i}); + } + } - // Broadcast the input kernel where the forward op expanded from a size == 1 - // dimension to a size > 1 dimension. This has the effect of summing the - // gradient contributions in that dimension. - for (int i = 0; i < num_spatial_dims; ++i) { - if (in_size[i] == 1 && grad_size[i] > 1) { - kernel = builder->Add(kernel, builder->ConstantR1(grad_size[i], 0), - /*broadcast_dimensions=*/{i}); + output = builder->ConvGeneralDilated( + grad, kernel, /*window_strides=*/dims.kernel_size, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.stride, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } else { + xla::XlaOp kernel0 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); + xla::XlaOp kernel1 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); + + // Broadcast the input kernel where the forward op expanded from a size == 1 + // dimension to a size > 1 dimension. This has the effect of summing the + // gradient contributions in that dimension. 
+ if (in_size[0] == 1 && grad_size[0] > 1) { + kernel0 = + builder->Add(kernel0, builder->ConstantR1(grad_size[0], 0), + /*broadcast_dimensions=*/{0}); + } + if (in_size[1] == 1 && grad_size[1] > 1) { + kernel1 = + builder->Add(kernel1, builder->ConstantR1(grad_size[1], 0), + /*broadcast_dimensions=*/{1}); } - } - xla::XlaOp output = builder->ConvGeneralDilated( - grad, kernel, /*window_strides=*/dims.kernel_size, - /*padding=*/ - {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, - {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, - /*lhs_dilation=*/dims.stride, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + output = builder->ConvGeneralDilated( + grad, kernel0, /*window_strides=*/{dims.kernel_size[0], 1}, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, + /*lhs_dilation=*/{dims.stride[0], 1}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + + output = builder->ConvGeneralDilated( + output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]}, + /*padding=*/ + {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/{1, dims.stride[1]}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i. // Opposite of the slice performed by the forward op. diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..0388b4c830702ea00ec69fc42c6468326c88cf38 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -0,0 +1,120 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific ListDiff Op. This only supports constant DT_INT32 and DT_INT64 +// input. + +#include + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { + +constexpr std::array<DataType, 2> kListDiffTypes = {DT_INT32, DT_INT64}; + +// ListDiffOp is an XLA kernel that supports constant-only x and y input.
+class ListDiffOp : public XlaOpKernel { + public: + explicit ListDiffOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + OP_REQUIRES(context, TensorShapeUtils::IsVector(context->InputShape(0)), + errors::InvalidArgument("ListDiff expects x as a vector, not ", + context->InputShape(0).DebugString())); + + OP_REQUIRES(context, TensorShapeUtils::IsVector(context->InputShape(1)), + errors::InvalidArgument("ListDiff expects y as a vector, not ", + context->InputShape(1).DebugString())); + + DataType val_type = context->expected_output_dtype(0); + DataType idx_type = context->expected_output_dtype(1); + + Status status; + switch (val_type) { + case DT_INT32: + status = ListDiffWithIndexType<int32>(context, idx_type); + break; + case DT_INT64: + status = ListDiffWithIndexType<int64>(context, idx_type); + break; + default: + // This should never happen since we restrict this kernel to only match + // inputs with supported Tensor datatype. + status = errors::InvalidArgument("ListDiff expects x and y as either ", + "int32 or int64, not ", + DataTypeString(val_type)); + } + OP_REQUIRES_OK(context, status); + } + + private: + template <typename T, typename Tidx> + Status ListDiff(XlaOpKernelContext* context) { + std::vector<int64> x_input, y_input; + TF_RETURN_IF_ERROR(context->ConstantInputAsIntVector(0, &x_input)); + TF_RETURN_IF_ERROR(context->ConstantInputAsIntVector(1, &y_input)); + + std::unordered_set<T> y_input_set; + y_input_set.reserve(y_input.size()); + for (auto y : y_input) { + y_input_set.insert(y); + } + + std::vector<T> val_output; + std::vector<Tidx> idx_output; + auto x_size = x_input.size(); + for (Tidx i = 0; i < x_size; ++i) { + if (y_input_set.count(x_input[i]) > 0) { + continue; + } + val_output.push_back(x_input[i]); + idx_output.push_back(i); + } + + context->SetOutput(0, context->builder()->ConstantR1<T>(val_output)); + context->SetOutput(1, context->builder()->ConstantR1<Tidx>(idx_output)); + return Status::OK(); + } + + template <typename T> + Status
ListDiffWithIndexType(XlaOpKernelContext* context, DataType idx_type) { + switch (idx_type) { + case DT_INT32: + return ListDiff<T, int32>(context); + case DT_INT64: + return ListDiff<T, int64>(context); + default: + return errors::InvalidArgument( + "ListDiff expects idx_out as either int32 or int64, not ", + DataTypeString(idx_type)); + } + } +}; + +REGISTER_XLA_OP(Name("ListDiff") + .TypeConstraint("T", kListDiffTypes) + .CompileTimeConstInput("x") + .CompileTimeConstInput("y"), + ListDiffOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/no_op.cc b/tensorflow/compiler/tf2xla/kernels/no_op.cc index 8c8a9bbe787f3224e7444b62dcf8ad99130cf37f..65ab9da8d7ca0509a4a69c43727a0e6c0435908a 100644 --- a/tensorflow/compiler/tf2xla/kernels/no_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/no_op.cc @@ -24,8 +24,7 @@ namespace tensorflow { REGISTER_XLA_OP(Name("NoOp").CompilationOnly(), NoOp); // We register ControlTrigger as a no-op. This is correct since nodes seen -// by the XLA compiler are never dead. This may need rethinking when we add -// support for conditionals to XLA. -REGISTER_XLA_OP(Name("ControlTrigger"), NoOp); +// by the XLA compiler are never dead. +REGISTER_XLA_OP(Name("ControlTrigger").CompilationOnly(), NoOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 5f5bd586376ab368e443671ac8a5de23a5fd604b..105be38fe26b6667e8b4ce6da92a3969cdc0c187 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -17,6 +17,9 @@ limitations under the License. // TODO(misard,phawkins): handle random number generator seeds/states correctly. // TODO(misard,phawkins): add tests.
+#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -55,6 +58,78 @@ class RandomUniformOp : public XlaOpKernel { REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstInput("shape"), RandomUniformOp); +class RandomShuffleOp : public XlaOpKernel { + public: + explicit RandomShuffleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + auto builder = ctx->builder(); + xla::XlaOp input = ctx->Input(0); + TensorShape input_shape = ctx->InputShape(0); + const int64 n = input_shape.dim_size(0); + int64 num_elements = 1; + for (tensorflow::TensorShapeDim dimension : input_shape) { + num_elements *= dimension.size; + } + if (num_elements <= 1 || n <= 1) { + // No shuffling is required, so copy input directly to output + ctx->SetOutput(0, input); + } else { + // Generate the random swaps for the indices. + auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n}); + auto swaps = + builder->RngUniform(builder->ConstantR0<int32>(0), + builder->ConstantR0<int32>(n), swaps_shape); + + // Generate range(n) as the initial value for the indices to be swapped. + xla::XlaOp indices; + TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, n, &indices)); + + // Swap the indices at i and swaps[i].
+ auto swap_body_fn = [&](xla::XlaOp i, + gtl::ArraySlice<xla::XlaOp> loop_vars, + xla::XlaBuilder* builder) + -> xla::StatusOr<std::vector<xla::XlaOp>> { + auto swaps = loop_vars[0]; + auto indices = loop_vars[1]; + i = builder->Reshape(i, {1}); + // temp = indices[i] + auto temp = builder->DynamicSlice(indices, i, {1}); + // swap_index = swaps[i] + auto swap_index = builder->DynamicSlice(swaps, i, {1}); + // swap_value = indices[swaps[i]] + auto swap_value = builder->DynamicSlice(indices, swap_index, {1}); + // indices[i] = indices[swaps[i]] + indices = builder->DynamicUpdateSlice(indices, swap_value, i); + // indices[swaps[i]] = temp + indices = builder->DynamicUpdateSlice(indices, temp, swap_index); + return std::vector<xla::XlaOp>{swaps, indices}; + }; + // for i in range(n): + auto swap_loop_result = + XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, + "indices_swap_loop", builder) + .ValueOrDie(); + auto swapped_indices = swap_loop_result[1]; + + // Gather the data using the swapped indices as the shuffled order. + auto indices_tensor_shape = TensorShape({n}); + DataType type = ctx->expected_output_dtype(0); + xla::XlaOp gather; + OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices, + indices_tensor_shape, + /*axis=*/0, /*indices_are_nd=*/false, type, + DT_INT32, builder, &gather)); + ctx->SetOutput(0, gather); + } + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleOp); +}; + +REGISTER_XLA_OP(Name("RandomShuffle"), RandomShuffleOp); + class RandomUniformIntOp : public XlaOpKernel { public: explicit RandomUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} @@ -127,13 +202,8 @@ class TruncatedNormalOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); - xla::Shape xla_element_shape = - xla::ShapeUtil::MakeShape(xla_shape.element_type(), {}); xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp mean = XlaHelpers::Zero(b, dtype); - xla::XlaOp
stddev = XlaHelpers::One(b, dtype); - xla::XlaOp candidate = b->RngNormal(mean, stddev, xla_shape); auto two_sd = [dtype](bool negate, xla::XlaBuilder* b) { return XlaHelpers::FloatLiteral(b, dtype, negate ? -2.0 : 2.0); @@ -151,34 +221,38 @@ class TruncatedNormalOp : public XlaOpKernel { // out_of_range_mask := candidate < mean-2*sd || candidate > mean+2*sd // candidate = select(out_of_range_mask, rng_normal(), candidate) // } - std::unique_ptr<xla::XlaBuilder> test_builder = - b->CreateSubBuilder("truncated_normal_test"); - { - auto* b = test_builder.get(); - xla::XlaOp candidate = b->Parameter(0, xla_shape, "candidate"); - out_of_range_mask(candidate, b); - OP_REQUIRES_OK(ctx, Any(out_of_range_mask(candidate, b), b).status()); - } - - std::unique_ptr<xla::XlaBuilder> body_builder = - b->CreateSubBuilder("truncated_normal_body"); - { - auto* b = body_builder.get(); - xla::XlaOp candidate = b->Parameter(0, xla_shape, "candidate"); - xla::XlaOp to_resample = out_of_range_mask(candidate, b); + std::vector<xla::XlaOp> initial_values = { + // The current candidate. + b->Broadcast(XlaHelpers::Zero(b, dtype), shape.dim_sizes()), + // The to_resample mask, where 'true' identifies a location in the + // current candidate that is out of range and must be regenerated. + b->Broadcast(b->ConstantR0<bool>(true), shape.dim_sizes()), + // Is any element in the mask true? + b->ConstantR0<bool>(true)}; + auto condition = [&](gtl::ArraySlice<xla::XlaOp> values, + xla::XlaBuilder* b) -> xla::StatusOr<xla::XlaOp> { + // Continue while any element in the mask is true.
+ return values[2]; + }; + auto body = + [&](gtl::ArraySlice<xla::XlaOp> values, + xla::XlaBuilder* b) -> xla::StatusOr<std::vector<xla::XlaOp>> { + xla::XlaOp candidate = values[0]; + xla::XlaOp to_resample = values[1]; xla::XlaOp mean = XlaHelpers::Zero(b, dtype); xla::XlaOp stddev = XlaHelpers::One(b, dtype); - b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), candidate); - } - - xla::StatusOr<xla::XlaComputation> test_computation = test_builder->Build(); - OP_REQUIRES_OK(ctx, test_computation.status()); - xla::StatusOr<xla::XlaComputation> body_computation = body_builder->Build(); - OP_REQUIRES_OK(ctx, body_computation.status()); - xla::XlaOp result = b->While(test_computation.ValueOrDie(), - body_computation.ValueOrDie(), candidate); - - ctx->SetOutput(0, result); + candidate = b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), + candidate); + // Compute a new to_resample mask, and determine whether any value is + // still out of range. + to_resample = out_of_range_mask(candidate, b); + TF_ASSIGN_OR_RETURN(xla::XlaOp done, Any(to_resample, b)); + return std::vector<xla::XlaOp>{candidate, to_resample, done}; + }; + auto result = + XlaWhileLoop(condition, body, initial_values, "truncated_normal", b); + OP_REQUIRES_OK(ctx, result.status()); + ctx->SetOutput(0, result.ValueOrDie()[0]); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index 70547290eaed169599764a5d66185dde85345863..a711278638444be01fb865561957702368b75114 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -55,18 +55,33 @@ class RetvalOp : public XlaOpKernel { } XlaContext& tc = XlaContext::Get(ctx); - if (input_shape.num_elements() == 0 || is_constant.ValueOrDie()) { + if (tc.resolve_compile_time_constants() && + (input_shape.num_elements() == 0 || is_constant.ValueOrDie())) { xla::Literal literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal)); OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal)); } else { - // The
core from which a return value is returned depends on the core - // assignment of the input to the retval .Since we can't change the core - // assignment of as this point, create a tuple/get-tuple-element - // combination so that the core will be set on them. - auto tuple_elem = - ctx->builder()->GetTupleElement(ctx->builder()->Tuple({input}), 0); - tc.AddRetval(index_, dtype_, tuple_elem); + TensorShape shape = ctx->InputShape(0); + TensorShape representation_shape = + tc.is_entry_computation() + ? tc.RepresentationShape(shape, ctx->input_type(0)) + : shape; + + xla::XlaOp output = input; + if (tc.is_entry_computation()) { + output = + ctx->builder()->Reshape(input, representation_shape.dim_sizes()); + } else { + // The core from which a return value is returned depends on the + // device assignment of the input to the retval. Since we can't change + // the device assignment of "input" at this point, we must always + // introduce an operator here, even if the shape does not change. + // TODO(b/76097077): propagate device assignments onto arguments and + // return values of functions, and then reshape unconditionally. + output = ctx->builder()->GetTupleElement( + ctx->builder()->Tuple({output}), 0); + } + tc.AddRetval(index_, dtype_, shape, output); } } } diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 0ed4c4707df71cf5f56ccfe0af506916f04bcdb5..5d1c05268493f4f6404c40a4092a71f1e5b3f3b9 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -106,20 +106,40 @@ class ReverseSequenceOp : public XlaOpKernel { seq_lens, body_builder->Reshape(i, {1}), {1}); // Indices is the offset of the batch element in the input. 
- auto indices = body_builder->Broadcast( + auto batch_element_indices = body_builder->Broadcast( XlaHelpers::Zero(body_builder.get(), seq_lens_type), {input_shape.dims()}); - indices = body_builder->DynamicUpdateSlice( - indices, body_builder->Reshape(i, {1}), + batch_element_indices = body_builder->DynamicUpdateSlice( + batch_element_indices, body_builder->Reshape(i, {1}), body_builder->Reshape( XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, batch_dim_), {1})); - // slice_indices is the offset of the start of the reversed sequence in - // the input. - auto slice_indices = body_builder->DynamicUpdateSlice( - indices, + // Slice out the current batch element and pad it out in the sequence + // dimension. + TensorShape slice_shape = input_shape; + slice_shape.set_dim(batch_dim_, 1); + slice_shape.set_dim(seq_dim_, max_seq_len); + auto slice = body_builder->DynamicSlice(output, batch_element_indices, + slice_shape.dim_sizes()); + auto padding_config = xla::MakeNoPaddingConfig(slice_shape.dims()); + padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high( + slice_shape.dim_size(seq_dim_)); + slice = body_builder->Pad( + slice, XlaHelpers::Zero(body_builder.get(), input_type), + padding_config); + + // Now slice out the reversed sequence from its actual start. + // sequence_start_indices is the offset of the start of the reversed + // sequence in the input. The slice will go into the padding; however, we + // will mask off these elements and replace them with elements from the + // original input so their values do not matter.
+ auto sequence_start_indices = body_builder->Broadcast( + XlaHelpers::Zero(body_builder.get(), seq_lens_type), + {slice_shape.dims()}); + sequence_start_indices = body_builder->DynamicUpdateSlice( + sequence_start_indices, body_builder->Sub(XlaHelpers::IntegerLiteral( body_builder.get(), seq_lens_type, max_seq_len), seq_len), @@ -127,18 +147,12 @@ class ReverseSequenceOp : public XlaOpKernel { XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, seq_dim_), {1})); - - // Slice out the reversed sequence. The slice will overflow the end of the - // sequence, and the contents of the overflow are implementation-defined. - // However, we will mask off these elements and replace them with elements - // from the original input so their values do not matter. - TensorShape slice_shape = input_shape; - slice_shape.set_dim(batch_dim_, 1); - auto slice = body_builder->DynamicSlice(output, slice_indices, - slice_shape.dim_sizes()); + slice = body_builder->DynamicSlice(slice, sequence_start_indices, + slice_shape.dim_sizes()); // Shift the reversed sequence to the left. 
- output = body_builder->DynamicUpdateSlice(output, slice, indices); + output = body_builder->DynamicUpdateSlice(output, slice, + batch_element_indices); body_builder->Tuple( {body_builder->Add( diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 05354bca5bb089703fdcceb6f44648bbb98d004b..d59720bef742c7441ee01a954247013559bb909c 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -43,7 +43,7 @@ class ShapeOp : public XlaOpKernel { DataType out_dtype_; }; -REGISTER_XLA_OP(Name("Shape"), ShapeOp); +REGISTER_XLA_OP(Name("Shape").CompilationOnly(), ShapeOp); class ShapeNOp : public XlaOpKernel { public: @@ -65,7 +65,7 @@ class ShapeNOp : public XlaOpKernel { private: DataType out_dtype_; }; -REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp); +REGISTER_XLA_OP(Name("ShapeN").CompilationOnly(), ShapeNOp); class RankOp : public XlaOpKernel { public: @@ -81,7 +81,7 @@ class RankOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Rank"), RankOp); +REGISTER_XLA_OP(Name("Rank").CompilationOnly(), RankOp); class SizeOp : public XlaOpKernel { public: @@ -100,7 +100,7 @@ class SizeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Size"), SizeOp); +REGISTER_XLA_OP(Name("Size").CompilationOnly(), SizeOp); class ExpandDimsOp : public XlaOpKernel { public: @@ -189,10 +189,9 @@ class SqueezeOp : public XlaOpKernel { if (!wrapped_squeeze_dims.empty()) { if (wrapped_squeeze_dims.count(i) > 0) { OP_REQUIRES(ctx, existing_dim == 1, - errors::InvalidArgument("Tried to explicitly squeeze " - "dimension ", - i, " but dimension was not 1: ", - existing_dim)); + errors::InvalidArgument( + "Tried to explicitly squeeze dimension ", i, + " but dimension was not 1: ", existing_dim)); } else { // This dimension is not being squeezed. 
new_shape.push_back(existing_dim); diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index a4f50f52ebe8b1ed7df862996d64e135ea1d0ac5..71a9fd051bfc8db09738a4bfe8ddde447895ecf0 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -100,8 +100,7 @@ XLAJIT_MAKE_UNARY(Cosh, XLAJIT_MAKE_UNARY(Sin, b->Sin(x)); XLAJIT_MAKE_UNARY(Exp, b->Exp(x)); -// TODO(b/34703906): use a more accurate implementation of expm1. -XLAJIT_MAKE_UNARY(Expm1, b->Sub(b->Exp(x), XlaHelpers::One(b, input_type(0)))); +XLAJIT_MAKE_UNARY(Expm1, b->Expm1(x)); XLAJIT_MAKE_UNARY(Floor, b->Floor(x)); XLAJIT_MAKE_UNARY(IsFinite, b->IsFinite(x)); @@ -115,8 +114,7 @@ XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x)); XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x)); XLAJIT_MAKE_UNARY(Log, b->Log(x)); -// TODO(b/34703906): use a more accurate implementation of log1p. -XLAJIT_MAKE_UNARY(Log1p, b->Log(b->Add(XlaHelpers::One(b, input_type(0)), x))); +XLAJIT_MAKE_UNARY(Log1p, b->Log1p(x)); XLAJIT_MAKE_UNARY(Invert, b->Not(x)); XLAJIT_MAKE_UNARY(LogicalNot, b->Not(x)); @@ -160,24 +158,17 @@ XLAJIT_MAKE_UNARY(Sinh, b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))), XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -static xla::XlaOp Softplus(xla::XlaBuilder* b, DataType dtype, - const xla::XlaOp& features) { - xla::XlaOp threshold = b->Add(b->Log(XlaHelpers::Epsilon(b, dtype)), - XlaHelpers::FloatLiteral(b, dtype, 2.0)); - // Value above which exp(x) may overflow, but softplus(x) == x - // is within machine epsilon. - xla::XlaOp too_large = b->Gt(features, b->Neg(threshold)); - // Value below which exp(x) may underflow, but softplus(x) == exp(x) - // is within machine epsilon. 
- xla::XlaOp too_small = b->Lt(features, threshold); - xla::XlaOp features_exp = b->Exp(features); - xla::XlaOp output = b->Select( - too_large, features, - b->Select(too_small, features_exp, - b->Log(b->Add(features_exp, XlaHelpers::One(b, dtype))))); - return output; -} -XLAJIT_MAKE_UNARY(Softplus, Softplus(b, input_type(0), x)); +// softplus(x) = log(1 + exp(x)) +// +// This is not numerically stable when x is large: exp(x) can easily overflow. +// However, we can compute it as LogSumExp(x, 0): +// max(x, 0) + log(exp(x - max(x, 0)) + exp(0 - max(x, 0))) +// +// This is equivalent to: +// max(x, 0) + log1p(exp(-abs(x))) +XLAJIT_MAKE_UNARY(Softplus, + b->Add(b->Max(x, XlaHelpers::Zero(b, input_type(0))), + b->Log1p(b->Exp(b->Neg(b->Abs(x)))))); // softsign(x) = x / (abs(x) + 1) XLAJIT_MAKE_UNARY(Softsign, diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 6109db8e89e5ee67e0635d26e258bfe7cb70a15d..a163fa0a5b34675e46d0d7c5f4e0ccb1e3fb18eb 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -57,7 +57,7 @@ class ReadVariableOp : public XlaOpKernel { private: DataType dtype_; }; -REGISTER_XLA_OP(Name("ReadVariableOp"), ReadVariableOp); +REGISTER_XLA_OP(Name("ReadVariableOp").CompilationOnly(), ReadVariableOp); class AssignVariableOp : public XlaOpKernel { public: @@ -67,7 +67,7 @@ class AssignVariableOp : public XlaOpKernel { ctx->AssignVariable(0, ctx->input_type(1), ctx->Input(1))); } }; -REGISTER_XLA_OP(Name("AssignVariableOp"), AssignVariableOp); +REGISTER_XLA_OP(Name("AssignVariableOp").CompilationOnly(), AssignVariableOp); class AssignAddVariableOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 04ad3694a0c0df9d43c706d428c3b8715e5ff8ca..ee7f5d510ab7a3ce7d3bbe843c5fefd362f79b7b 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++
b/tensorflow/compiler/tf2xla/lib/BUILD @@ -80,7 +80,6 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -141,7 +140,6 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/tests:client_library_test_base", diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index 83e73827862ca26a1a51bed72ab87768854c1e71..20925118bf598a6436c43bd727ce40e3abafc46c 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -110,7 +110,6 @@ xla::StatusOr CholeskyUnblocked(xla::XlaBuilder* builder, FloatLiteral(body_builder, a_shape.element_type(), 0.5)); // a[..., i+1:, i] - auto ip1 = body_builder->Add(i, body_builder->ConstantR0(1)); // select the whole i-th column, then mask out all rows above i+1 TF_ASSIGN_OR_RETURN( auto a_0i, DynamicSliceInMinorDims(body_builder, body_a, {i}, {1})); @@ -214,7 +213,7 @@ xla::StatusOr Cholesky(xla::XlaBuilder* builder, xla::XlaOp a, /*lower=*/true, /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/8)); + /*block_size=*/block_size)); TF_ASSIGN_OR_RETURN( l, UpdateSliceInMinorDims(builder, l, update, {i + k, i})); } diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 2c3cd658e0462368ac0b51938979b7a6815a7574..db56b128375ce8ff2faf12c5d7ea256bdfab0f63 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -40,7 
+40,38 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { return Status::OK(); } -Status CopyLiteralToHostTensor(const xla::Literal& literal, +Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, + xla::BorrowingLiteral* literal) { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), + host_tensor.shape(), &xla_shape)); + *literal = xla::BorrowingLiteral( + static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape); + return Status::OK(); +} + +Status HostTensorsToBorrowingLiteralTuple( + tensorflow::gtl::ArraySlice<Tensor> host_tensors, + xla::BorrowingLiteral* literal) { + std::vector<const char*> buf_ptrs; + buf_ptrs.reserve(host_tensors.size()); + std::vector<xla::Shape> tensor_shapes(host_tensors.size()); + + for (int i = 0; i < host_tensors.size(); i++) { + // Validate runtime shapes and fail if they don't match the contract. + const Tensor* tensor = &host_tensors[i]; + buf_ptrs.emplace_back(static_cast<const char*>(DMAHelper::base(tensor))); + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(tensor->dtype(), tensor->shape(), + &tensor_shapes[i])); + } + + *literal = xla::BorrowingLiteral( + buf_ptrs, xla::ShapeUtil::MakeTupleShape(tensor_shapes)); + + return Status::OK(); +} + +Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, Tensor* host_tensor) { TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) && xla::ShapeUtil::ElementsIn(literal.shape()) == @@ -63,8 +94,8 @@ Status CopyLiteralToHostTensor(const xla::Literal& literal, return Status::OK(); } -Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, - Tensor* host_tensor) { +Status LiteralToHostTensor(const xla::LiteralSlice& literal, + DataType target_type, Tensor* host_tensor) { TensorShape shape; TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape)); *host_tensor = Tensor(target_type, shape); diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index
f283b0236811f8d52e8fe2982a74c11c92cd20d8..74685025c1780c5c0ba56205a98786582e9191e9 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -29,6 +30,17 @@ namespace tensorflow { // unsupported type. Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); +// Returns a BorrowingLiteral that utilizes the same underlying buffer owned by +// 'host_tensor'. +Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, + xla::BorrowingLiteral* literal); + +// Returns a BorrowingLiteral tuple that utilizes the same underlying buffers +// owned by 'host_tensors'. +Status HostTensorsToBorrowingLiteralTuple( + tensorflow::gtl::ArraySlice<Tensor> host_tensors, + xla::BorrowingLiteral* literal); + // Copies 'literal' to freshly allocated 'host_tensor', which is allocated of // type <target_type>. // Fails if the literal's primitive type != @@ -36,13 +48,13 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); // derivable from the type of <literal>, because multiple tensorflow types map // to the same XLA type (e.g. INT32 and QINT32 both map to INT32 in // XLA). -Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, - Tensor* host_tensor); +Status LiteralToHostTensor(const xla::LiteralSlice& literal, + DataType target_type, Tensor* host_tensor); // Copies the contents of 'literal' to a previously allocated tensor // 'host_tensor'. The tensor and the literal must have the same number of // elements and the same type.
-Status CopyLiteralToHostTensor(const xla::Literal& literal, +Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, Tensor* host_tensor); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 3a08aa8cf4f5cea6210cc9470d57c3387445ea6e..ac768b206e2a8d163a4253432a1911152f89ce86 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -263,8 +263,7 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, // Compile the graph into an XLA computation. XlaCompiler::Options compiler_options; compiler_options.client = client; - DeviceType device_type(DEVICE_CPU_XLA_JIT); - compiler_options.device_type = &device_type; + compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); compiler_options.flib_def = &graph->flib_def(); compiler_options.graph_def_version = graph->versions().producer(); compiler_options.allow_cpu_custom_calls = true; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 3d1946c332b0f903b710a19fbb79fc9923e89c43..9c8e56a17e07348d3cfaaca0b5eb335295af05c3 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -15,10 +15,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include #include +#include -#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" @@ -28,7 +27,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" @@ -40,7 +38,6 @@ limitations under the License. #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/public/version.h" namespace tensorflow { namespace { @@ -86,12 +83,9 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) : options_(options), initialization_status_(Status::OK()), next_step_id_(1), - device_( - new XlaCompilationDevice(SessionOptions(), *options_.device_type)), + device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)), device_mgr_({device_}) { - // We no longer need the device_type. - options_.device_type = nullptr; - + CHECK(!options_.device_type.type_string().empty()); if (options_.populate_resource_manager) { initialization_status_ = (*options_.populate_resource_manager)(device_->resource_manager()); @@ -110,10 +104,10 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) local_flib_runtime_ = local_pflr_->GetFLR(device_->name()); flib_runtime_ = pflr_->GetFLR(device_->name()); - // The default variable representation shape is the identity function. - if (!options_.variable_representation_shape_fn) { - options_.variable_representation_shape_fn = - [](const TensorShape& shape, DataType type) { return shape; }; + // The default shape representation function is the identity. 
+ if (!options_.shape_representation_fn) { + options_.shape_representation_fn = [](const TensorShape& shape, + DataType type) { return shape; }; } } @@ -230,20 +224,25 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, // Computes the XLA shape for argument 'arg'. Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, - xla::Shape* xla_shape) { + bool is_entry_computation, + xla::Shape* xla_shape) const { switch (arg.kind) { case XlaCompiler::Argument::kConstant: - return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(), - xla_shape); - case XlaCompiler::Argument::kParameter: - return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape); + LOG(FATAL) << "Unreachable case"; + case XlaCompiler::Argument::kParameter: { + TensorShape shape = + is_entry_computation + ? options_.shape_representation_fn(arg.shape, arg.type) + : arg.shape; + return TensorShapeToXLAShape(arg.type, shape, xla_shape); + } case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); switch (arg.resource_kind) { case XlaResource::kVariable: { TensorShape representation_shape = - options_.variable_representation_shape_fn(arg.shape, arg.type); + options_.shape_representation_fn(arg.shape, arg.type); return TensorShapeToXLAShape(arg.type, representation_shape, xla_shape); } @@ -337,16 +336,25 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, Status BuildComputation( const std::vector& args, const std::vector& arg_cores, - const std::vector& retvals, + const std::vector& retvals, const std::vector>& resources, bool return_updated_values_for_all_resources, xla::XlaBuilder* builder, xla::XlaComputation* computation, int* num_computation_outputs, int* num_nonconst_outputs, + std::vector* outputs, std::vector* resource_updates) { std::vector elems; elems.reserve(retvals.size()); - for (const XlaExpression& retval : retvals) { - if (!retval.has_constant_value()) { + for (int i = 0; i < retvals.size(); ++i) { + 
XlaCompiler::OutputDescription& output = (*outputs)[i]; + output.type = retvals[i].type; + output.shape = retvals[i].shape; + const XlaExpression& retval = retvals[i].expression; + if (retval.has_constant_value()) { + output.is_constant = true; + output.constant_value = retval.constant_value(); + } else { + output.is_constant = false; elems.push_back(retval.handle()); } } @@ -490,8 +498,8 @@ Status XlaCompiler::BuildArguments( std::vector arg_shapes(input_mapping->size()); for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { // Computes the shapes of non-constant arguments. - TF_RETURN_IF_ERROR( - XLAShapeForArgument(args[(*input_mapping)[i]], &arg_shapes[i])); + TF_RETURN_IF_ERROR(XLAShapeForArgument( + args[(*input_mapping)[i]], is_entry_computation, &arg_shapes[i])); } if (use_tuple_arg) { @@ -567,7 +575,8 @@ Status XlaCompiler::BuildArguments( builder->ClearOpMetadata(); - // Fill in the handles in non-constant arguments. + // Fill in the handles in non-constant arguments, and reshape parameters + // back to their correct shapes. VLOG(2) << "XLA computation inputs:"; for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { const XlaCompiler::Argument& arg = args[input_mapping->at(i)]; @@ -586,7 +595,15 @@ Status XlaCompiler::BuildArguments( break; } case XlaCompiler::Argument::kParameter: - arg_expression.set_handle(arg_handles[i]); + // Reshape parameters back to their correct shapes. + // TODO(b/76097077): propagate device assignments onto arguments and + // return values of functions, and then reshape unconditionally. 
+ if (is_entry_computation) { + arg_expression.set_handle( + builder->Reshape(arg_handles[i], arg.shape.dim_sizes())); + } else { + arg_expression.set_handle(arg_handles[i]); + } break; case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kInvalid: @@ -635,10 +652,70 @@ Status XlaCompiler::CompileSingleOp( .Finalize(graph.get(), &node); TF_RETURN_IF_ERROR(status); } + FixupSourceAndSinkEdges(graph.get()); return CompileGraph(options, name, std::move(graph), args, result); } +namespace { + +// Check that the ops of all non-functional nodes have been registered. +string ValidateFunctionDef(const FunctionDef* fdef, + const FunctionLibraryDefinition& flib_def) { + std::vector invalid_ops; + for (const NodeDef& node : fdef->node_def()) { + const string& op = node.op(); + if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) { + continue; + } + const OpDef* op_def; + if (!OpRegistry::Global()->LookUpOpDef(op, &op_def).ok()) { + invalid_ops.push_back(op); + } + } + return tensorflow::str_util::Join(invalid_ops, ", "); +} + +// Check that the graph doesn't have any invalid nodes (e.g. incompatible with +// given device_type, invalid data type, missing attributes...) 
+Status ValidateGraph(const Graph* graph, + const FunctionLibraryDefinition& flib_def, + const DeviceType& device_type, const string& name) { + std::vector invalid_ops; + for (const Node* node : graph->nodes()) { + if (node->type_string() == FunctionLibraryDefinition::kGradientOp) { + continue; + } + const FunctionDef* fdef = flib_def.Find(node->def().op()); + if (fdef) { + string error_msg = ValidateFunctionDef(fdef, flib_def); + if (!error_msg.empty()) { + invalid_ops.push_back( + strings::StrCat(node->def().op(), ":{", error_msg, "}")); + } + continue; + } + const OpDef* op_def; + if (!OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def).ok()) { + invalid_ops.push_back(node->def().op()); + continue; + } + TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def)); + if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) { + invalid_ops.push_back(node->def().op()); + } + } + if (!invalid_ops.empty()) { + return errors::InvalidArgument(strings::StrCat( + "Detected unsupported operations when trying to compile graph ", name, + " on ", device_type.type_string(), ":", + tensorflow::str_util::Join(invalid_ops, ", "))); + } + return Status::OK(); +} + +} // namespace + Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr graph, @@ -658,13 +735,19 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, // Converts Tensorflow's graph control-flow constructs into functional // control-flow that can be compiled into XLA code. TF_RETURN_IF_ERROR( - FunctionalizeControlFlow(graph.get(), local_flib_def_.get())); + FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(), + graph.get(), local_flib_def_.get())); + + // Detect invalid nodes. + // FunctionalizeControlFlow may remove some nodes from the graph. 
+ TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def, + options_.device_type, name)); xla::XlaBuilder builder(name); - XlaContext* context = - new XlaContext(this, &builder, options_.allow_cpu_custom_calls, - options.resolve_compile_time_constants, - &options_.variable_representation_shape_fn); + XlaContext* context = new XlaContext( + this, &builder, options_.allow_cpu_custom_calls, + options.resolve_compile_time_constants, options.is_entry_computation, + &options_.shape_representation_fn); core::ScopedUnref context_unref(context); std::vector arg_expressions; @@ -681,35 +764,22 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, int num_nonconst_outputs; int num_computation_outputs; result->computation = std::make_shared(); + result->outputs.resize(context->retvals().size()); TF_RETURN_IF_ERROR(BuildComputation( args, arg_cores, context->retvals(), context->resources(), options.return_updated_values_for_all_resources, &builder, result->computation.get(), &num_computation_outputs, - &num_nonconst_outputs, &result->resource_updates)); + &num_nonconst_outputs, &result->outputs, &result->resource_updates)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; - result->outputs.resize(context->retvals().size()); - for (std::vector::size_type i = 0; - i < context->retvals().size(); ++i) { - const XlaExpression& retval = context->retvals()[i]; - if (retval.has_constant_value()) { - OutputDescription& output = result->outputs[i]; - output.shape = retval.constant_value().shape(); - output.is_constant = true; - output.constant_value = retval.constant_value(); - } - } - // Compute the output shapes, if there is a computation with non-constant + // Compute the XLA output shape, if there is a computation with non-constant // outputs. 
- auto computation_shape = client()->GetComputationShape(*result->computation); - if (!computation_shape.ok()) { - return computation_shape.status(); - } + TF_ASSIGN_OR_RETURN(std::unique_ptr computation_shape, + client()->GetComputationShape(*result->computation)); - result->xla_output_shape.Swap( - computation_shape.ValueOrDie()->mutable_result()); + result->xla_output_shape.Swap(computation_shape->mutable_result()); VLOG(2) << "XLA output shape: " << xla::ShapeUtil::HumanString(result->xla_output_shape); @@ -724,23 +794,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, // Tensorflow expects a major-to-minor order of results. xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape); - // Converts the output shapes to TensorShapes. - int computation_output = 0; - for (std::vector::size_type i = 0; - i < context->retvals().size(); ++i) { - const XlaExpression& retval = context->retvals()[i]; - if (!retval.has_constant_value()) { - TF_RET_CHECK(computation_output < num_computation_outputs) - << "Computation has more outputs than expected"; - OutputDescription& output = result->outputs[i]; - output.is_constant = false; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape( - xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape, - computation_output), - &output.shape)); - ++computation_output; - } - } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index ca6cd822ef4effd48dbc3cc18d35d6642f303df1..c93850ce270502ea1df1f6469963e96e86994fa2 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -38,7 +39,7 @@ class XlaContext; // It does a symbolic execution of the graph starting from specific input // shapes, using a JIT device to convert operators into XLA computations. // -// XlaCompiler is typically invoked from an `_XlaLaunch` operator once the +// XlaCompiler is typically invoked from an `XlaLaunch` operator once the // shapes of all input parameters to the computation are known. This is // because the symbolic execution requires known shapes for all operations. // @@ -67,6 +68,15 @@ class XlaContext; // _Retval values are ordered by _Retval index, whereas kResource values are // ordered by the original _Arg position of the variable. // +// If a shape representation function is provided as part of +// XlaCompiler::CompileOptions, kParameter arguments and return values to an +// entry computation will be reshaped in accordance to the shape function. +// Arguments and return values to a non-entry computation are not reshaped. +// Variable resource arguments are passed and returned in reshaped form, even +// for non-entry computations. This feature allows TensorFlow to keep on-device +// tensors with a different shape to their representation inside the XLA +// computation. +// // In both inputs and outputs, kResource values are placed the end. When // emitting While loop bodies, we must ensure that the loop body has // identical input and output signatures. By moving variable values @@ -171,7 +181,7 @@ class XlaCompiler { }; struct OutputDescription { - // Type and shape of the output. + // Type and shape of the output. The shape is the unflattened shape. 
DataType type; TensorShape shape; @@ -206,10 +216,12 @@ class XlaCompiler { // original arguments, and are not necessarily in the same order.) std::vector input_mapping; - // Input shapes of the computation. + // Input shapes of the computation. If we are flattening inputs, these are + // the flattened shapes. std::vector xla_input_shapes; - // Output shape in XLA format. The output shape is always a tuple. + // Output shape in XLA format. The output shape is always a tuple. If we + // are flattening outputs, these are the flattened shapes. xla::Shape xla_output_shape; // TensorFlow shapes of outputs, together with the values of any @@ -230,10 +242,12 @@ class XlaCompiler { std::shared_ptr computation; }; + typedef std::function + ShapeRepresentationFn; struct Options { - // Name of the compilation device to use. Needs to be live only during - // XlaCompiler's constructor. - const DeviceType* device_type = nullptr; + // Name of the compilation device to use. It must be set by the caller. + // The default empty value is invalid. + DeviceType device_type = DeviceType(""); xla::Client* client = nullptr; @@ -250,8 +264,7 @@ class XlaCompiler { // If set, the XLA representation of variables represented to XLA as the // shape given by this shape function. Variables are reshaped to this shape // on write, and reshaped to their original shape on read. - std::function - variable_representation_shape_fn; + ShapeRepresentationFn shape_representation_fn; // If not nullptr, populate_resource_manager is called with the // compilation device's resource manager when the compilation @@ -300,7 +313,8 @@ class XlaCompiler { // Returns the shape of the XLA parameter for an argument 'arg'. // See the class comment for more details about the argument passing // convention. 
- Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape); + Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation, + xla::Shape* xla_shape) const; // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 6b8918b26179735a4518a422fed024fa534122f5..613230452b74755ce7543ec2ab82861aa0dfeb7a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -25,13 +25,16 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -43,8 +46,6 @@ namespace tensorflow { class XlaCompilerTest : public ::testing::Test { protected: - XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {} - void SetUp() override { client_ = xla::ClientLibrary::LocalClientOrDie(); @@ -56,7 +57,7 @@ class XlaCompilerTest : public ::testing::Test { XlaCompiler::Options DefaultOptions() { XlaCompiler::Options options; - options.device_type = &cpu_device_type_; + options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); options.client = client_; 
options.flib_def = flib_def_.get(); return options; @@ -66,7 +67,6 @@ class XlaCompilerTest : public ::testing::Test { return compiler->local_flib_def_.get(); } - DeviceType cpu_device_type_; xla::Client* client_; std::unique_ptr flib_def_; }; @@ -225,7 +225,7 @@ TEST_F(XlaCompilerTest, Simple) { xla::Literal::CreateR1({4, 143}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({expected0.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { @@ -320,7 +320,8 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { xla::Literal::CreateR1({-7, -42}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({expected0.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE( + xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } { @@ -355,10 +356,80 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { xla::Literal::CreateR1({-7, -42}); std::unique_ptr expected = xla::Literal::MakeTuple({expected0.get(), expected1.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal)); } } +TEST_F(XlaCompilerTest, ConstantOutputsOfFunctionalNode) { + // Define a function with one compile-time constant output and one + // data-dependent output. 
+ // @function.Defun(noinline=True) + // foo(a) {b=7; return b, a; } + const Tensor seven = test::AsScalar(7); + FunctionDef fdef = FunctionDefHelper::Create( + "foo", {"a_0:int32"}, {"const:int32", "a:int32"}, {}, + { + {{"Const"}, "Const", {}, {{"dtype", DT_INT32}, {"value", seven}}}, + }, + {{"a", "a_0"}, {"const", "Const:output:0"}}); + (*fdef.mutable_attr())["_noinline"].set_b(true); + FunctionDefLibrary fdef_lib; + *(fdef_lib.add_function()) = fdef; + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(fdef_lib)); + auto arg = ops::_Arg(scope.WithOpName("input_arg"), DT_INT32, 0); + NodeDef foo; + foo.set_name("foo"); + foo.set_op("foo"); + *foo.add_input() = "input_arg"; + Status status; + scope.graph()->AddNode(foo, &status); + TF_ASSERT_OK(status); + NodeDef retval_1; + retval_1.set_name("retval_0"); + retval_1.set_op(FunctionLibraryDefinition::kRetOp); + *retval_1.add_input() = "foo"; + (*retval_1.mutable_attr())["T"].set_type(DT_INT32); + (*retval_1.mutable_attr())["index"].set_i(0); + scope.graph()->AddNode(retval_1, &status); + TF_ASSERT_OK(status); + NodeDef retval_2; + retval_2.set_name("retval_1"); + retval_2.set_op(FunctionLibraryDefinition::kRetOp); + *retval_2.add_input() = "foo:1"; + (*retval_2.mutable_attr())["T"].set_type(DT_INT32); + (*retval_2.mutable_attr())["index"].set_i(1); + scope.graph()->AddNode(retval_2, &status); + TF_ASSERT_OK(status); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + } + + // Builds a description of the arguments. 
+ std::vector args(1); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({1}); + + XlaCompiler::Options options = DefaultOptions(); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib); + options.flib_def = &flib_def; + XlaCompiler compiler(options); + + XlaCompiler::CompileOptions compile_options; + compile_options.resolve_compile_time_constants = true; + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants", + std::move(graph), args, &result)); + + ASSERT_EQ(2, result.outputs.size()); + EXPECT_TRUE(result.outputs[0].is_constant); + test::ExpectTensorEqual(result.outputs[0].constant_value, + test::AsScalar(7)); + EXPECT_FALSE(result.outputs[1].is_constant); +} + // Tests compilation and execution of a graph that adds two tensors. TEST_F(XlaCompilerTest, ResourceManager) { // Builds a graph that calls the dummy resource Op. @@ -523,7 +594,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { {output_base.get(), output_grad1.get(), output_grad2.get()}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({output_read.get(), output_resource.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } // Tests compilation and execution of a graph that adds two tensors. @@ -746,13 +817,10 @@ TEST_F(XlaCompilerTest, Variables) { xla::Literal::CreateR1({4, 143}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({expected0.get(), expected1.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } -// Tests a simple graph that reads and writes a variable, with a -// variable_representation_shape_fn passed to the compiler that flattens all -// variable tensors to vectors. 
-TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { +xla::StatusOr> BuildTestGraph() { Scope scope = Scope::NewRootScope().ExitOnError(); auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); @@ -763,7 +831,15 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); std::unique_ptr graph(new Graph(OpRegistry::Global())); - TF_ASSERT_OK(scope.ToGraph(graph.get())); + TF_RETURN_IF_ERROR(scope.ToGraph(graph.get())); + return std::move(graph); +} + +// Tests a simple graph that reads and writes a variable, with a +// shape_representation_fn passed to the compiler that flattens all +// variable tensors to vectors. +TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, BuildTestGraph()); // Builds a description of the arguments. std::vector args(2); @@ -778,15 +854,33 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { // Compiles the graph. XlaCompiler::Options options = DefaultOptions(); - options.variable_representation_shape_fn = [](const TensorShape& shape, - DataType type) { + options.shape_representation_fn = [](const TensorShape& shape, + DataType type) { return TensorShape({shape.num_elements()}); }; XlaCompiler compiler(options); + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = false; // Only reshape variables. 
+ XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", - std::move(graph), args, &result)); + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr program_shape, + client_->GetComputationShape(*result.computation)); + + ASSERT_EQ(program_shape->parameters_size(), 2); + EXPECT_TRUE( + xla::ShapeUtil::Compatible(program_shape->parameters(0), + xla::ShapeUtil::MakeShape(xla::S32, {2, 2}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->result(), + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {2, 2}), + xla::ShapeUtil::MakeShape(xla::S32, {4})}))); // Tests that the generated computation works. std::unique_ptr param0_literal = @@ -811,7 +905,186 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { xla::Literal::CreateR1({26, 66, 34, 401}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({expected0.get(), expected1.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); +} + +TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, BuildTestGraph()); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 2}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2, 2}); + + // Compiles the graph. 
+ XlaCompiler::Options options = DefaultOptions(); + options.shape_representation_fn = [](const TensorShape& shape, + DataType type) { + return TensorShape({shape.num_elements()}); + }; + XlaCompiler compiler(options); + + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; // Reshape args and retvals. + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr program_shape, + client_->GetComputationShape(*result.computation)); + + ASSERT_EQ(program_shape->parameters_size(), 2); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->parameters(0), xla::ShapeUtil::MakeShape(xla::S32, {4}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->result(), + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {4}), + xla::ShapeUtil::MakeShape(xla::S32, {4})}))); + + // Tests that the generated computation works. 
+ std::unique_ptr param0_literal = + xla::Literal::CreateR1({4, 55, 1, -3}); + std::unique_ptr param1_literal = + xla::Literal::CreateR1({22, 11, 33, 404}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + std::unique_ptr actual = + client_ + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr actual_literal = + client_->Transfer(*actual).ConsumeValueOrDie(); + + std::unique_ptr expected0 = + xla::Literal::CreateR1({27, 67, 35, 402}); + std::unique_ptr expected1 = + xla::Literal::CreateR1({26, 66, 34, 401}); + std::unique_ptr expected_literal = + xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); +} + +// Tests a graph which has a function with an invalid op. +TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { + XlaCompiler compiler(DefaultOptions()); + + FunctionDefLibrary flib; + FunctionDef fn = FillFn(); + NodeDef* node = fn.add_node_def(); + node->set_name("Invalid"); + node->set_op("InvalidOp"); /* unsupported op */ + node = fn.add_node_def(); + node->set_name("Switch"); + node->set_op("Switch"); /* control flow node */ + *flib.add_function() = fn; + + TF_ASSERT_OK(flib_def_->AddFunctionDef(fn)); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + Scope scope = Scope::NewRootScope().ExitOnError(); + auto value = ops::Const(scope.WithOpName("value"), 1, {}); + auto shape = ops::Const(scope.WithOpName("shape"), {5}, {1}); + TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib)); + + NodeDef def; + TF_ASSERT_OK(NodeDefBuilder("fill_fn", "FillFn", flib_def_.get()) + .Input(value.name(), 0, DT_INT32) + .Input(shape.name(), 1, DT_INT32) + .Finalize(&def)); + Status status; + Node* fill = scope.graph()->AddNode(def, &status); + TF_ASSERT_OK(status); + 
TF_ASSERT_OK(scope.DoShapeInference(fill)); + scope.graph()->AddEdge(value.node(), 0, fill, 0); + scope.graph()->AddEdge(shape.node(), 0, fill, 1); + + auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0); + + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + std::vector args; + XlaCompiler::CompilationResult result; + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", + std::move(graph), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE( + str_util::StrContains(status.error_message(), "FillFn:{InvalidOp}")) + << status.error_message(); +} + +// Tests a graph which has a node with invalid data type. +TEST_F(XlaCompilerTest, NodeWithInvalidDataType) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + NodeDef shape; + shape.set_name("Shape"); + shape.set_op("Shape"); + (*shape.mutable_attr())["T"].set_type(DT_INT32); + (*shape.mutable_attr())["out_type"].set_type(DT_BOOL); /* invalid type */ + Status status; + Node* shape_node = graph->AddNode(shape, &status); + TF_ASSERT_OK(status); + graph->AddControlEdge(graph->source_node(), shape_node); + + std::vector args; + XlaCompiler::CompilationResult result; + XlaCompiler compiler(DefaultOptions()); + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type", + std::move(graph), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), + "is not in the list of allowed values")) + << status.error_message(); +} + +TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + NodeDef no_op; + no_op.set_name("NoOp"); + no_op.set_op("NoOp"); + Status status; + graph->AddNode(no_op, &status); + TF_ASSERT_OK(status); + + std::vector args; + XlaCompiler compiler(DefaultOptions()); + // No control edge linking NoOp with source/sink. 
+ { + std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + XlaCompiler::CompilationResult result; + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", + std::move(graph_copy), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), + "The following nodes are unreachable " + "from the source in the graph: NoOp")) + << status.error_message(); + } + + // Fix control edges for NoOp. + { + std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get())); + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", + std::move(graph_copy), args, &result)); + EXPECT_EQ(0, result.resource_updates.size()); + } } } // namespace diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 3dd2d183f3a538786856dd8d92c5886b1cc237d8..098072d33cd4eb7f7dec0ec4196b43eca0220d4a 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -65,26 +65,30 @@ void XlaContext::set_args(std::vector args) { XlaContext::XlaContext( XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, + bool is_entry_computation, const std::function* - variable_representation_shape_fn) + shape_representation_fn) : compiler_(compiler), builder_(builder), allow_cpu_custom_calls_(allow_cpu_custom_calls), resolve_compile_time_constants_(resolve_compile_time_constants), - variable_representation_shape_fn_(variable_representation_shape_fn) {} + is_entry_computation_(is_entry_computation), + shape_representation_fn_(shape_representation_fn) {} string XlaContext::DebugString() { return "TLA JIT context"; } // This is called by the Retval Op to associate a computed value // with a specific return value of the 
subgraph. void XlaContext::AddRetval(int retval_index, DataType type, - const xla::XlaOp& handle) { + const TensorShape& shape, const xla::XlaOp& handle) { VLOG(1) << "Added retval index " << retval_index << " to XLA computation"; // Add the return value to the list being built up. if (retvals_.size() <= retval_index) { retvals_.resize(retval_index + 1); } - retvals_[retval_index].set_handle(handle); + XlaExpression e; + e.set_handle(handle); + retvals_[retval_index] = Retval{type, shape, e}; } Status XlaContext::AddConstRetval(int retval_index, DataType dtype, @@ -94,13 +98,11 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype, if (retvals_.size() <= retval_index) { retvals_.resize(retval_index + 1); } - if (resolve_compile_time_constants_) { - Tensor value; - TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); - retvals_[retval_index].set_constant_value(std::move(value)); - } else { - retvals_[retval_index].set_handle(builder_->ConstantLiteral(literal)); - } + Tensor value; + TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); + XlaExpression e; + e.set_constant_value(value); + retvals_[retval_index] = Retval{dtype, value.shape(), e}; return Status::OK(); } @@ -117,9 +119,9 @@ Status XlaContext::CreateResource( return Status::OK(); } -TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape, - DataType type) const { - return (*variable_representation_shape_fn_)(shape, type); +TensorShape XlaContext::RepresentationShape(const TensorShape& shape, + DataType type) const { + return (*shape_representation_fn_)(shape, type); } const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 1136ffe5073a8e7fd3c27d6ec7050cb1f8307584..341bf6ff1f37fa7cd81f41c02a941214067b1bd1 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -42,11 +42,13 @@ 
class XlaContext : public ResourceBase { static XlaContext& Get(const OpKernelContext* ctx); static XlaContext& Get(const XlaOpKernelContext* ctx); - // Creates a new XlaContext. + // Creates a new XlaContext. See the documentation on the class data fields + // for descriptions of the arguments. XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, + bool is_entry_computation, const std::function* - variable_representation_shape_fn); + shape_representation_fn); // Virtual method defined by ResourceBase. string DebugString() override; @@ -58,14 +60,26 @@ class XlaContext : public ResourceBase { bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } + bool resolve_compile_time_constants() const { + return resolve_compile_time_constants_; + } + bool is_entry_computation() const { return is_entry_computation_; } + const std::vector& args() const { return args_; } void set_args(std::vector args); - const std::vector& retvals() { return retvals_; } + struct Retval { + DataType type; + TensorShape shape; + // An XlaExpression representing the Retval's value. + XlaExpression expression; + }; + const std::vector& retvals() { return retvals_; } // This is called by the Retval Op to associate a computed value // with a specific return value of the subgraph. - void AddRetval(int retval_index, DataType type, const xla::XlaOp& handle); + void AddRetval(int retval_index, DataType type, const TensorShape& shape, + const xla::XlaOp& handle); // As for Retval, but for return values that are compile-time constants. Status AddConstRetval(int retval_index, DataType dtype, @@ -86,9 +100,9 @@ class XlaContext : public ResourceBase { } // Returns the XLA shape to be used to represent a variable of TF `shape` - // and `type`. - TensorShape VariableRepresentationShape(const TensorShape& shape, - DataType type) const; + // and `type`, or of an argument or return value of a top-level computation. 
+ TensorShape RepresentationShape(const TensorShape& shape, + DataType type) const; // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a @@ -131,15 +145,23 @@ class XlaContext : public ResourceBase { std::vector args_; // Return values of the Tensorflow graph, indexed by _Retval index. - std::vector retvals_; + std::vector retvals_; // Holds ownership of resources. The resources are not ordered. std::vector> resources_; - // A function that describes how variable shapes should be represented - // in XLA. Variable values will be reshaped to this shape. Must be non-null. + // Is this a top-level computation, or an inner computation (e.g., a while + // body)? + const bool is_entry_computation_; + + // A function that describes how the shapes of + // a) argument and return value, for entry computations + // b) variables, for all computations, + // should be represented in XLA. Parameters/return values will be shaped + // according to this function, and reshaped back to/from their declared shapes + // for computations. Must be non-null. const std::function* - variable_representation_shape_fn_; + shape_representation_fn_; // Cache of prebuilt computations indexed by their type. using ComputationMap = std::map; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index f1594193af09c7193f03b4685d3a7d4510d654dd..a1da176fe30ddd0d4460a51b60b2568ecc1af6aa 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -19,11 +19,13 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -210,8 +212,9 @@ Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, return errors::InvalidArgument("Invalid argument type ", DataTypeString(dtype)); } - xla::Literal linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal)); + xla::BorrowingLiteral linspace_literal; + TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); + *iota = builder->ConstantLiteral(linspace_literal); return Status::OK(); } @@ -245,8 +248,8 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, return errors::InvalidArgument("Invalid argument type ", DataTypeString(index_type)); } - xla::Literal linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal)); + xla::BorrowingLiteral linspace_literal; + TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); // Broadcast the linspace constant across the indices along the new axis, // and test equality at each position. 
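Note on the `XlaHelpers::OneHot` hunk above: its trailing comment describes broadcasting a linspace constant across the indices along the new axis and testing equality at each position. A minimal standalone sketch of that idea in plain C++ follows. It has no XLA dependency, and `OneHotSketch` and its parameters are hypothetical names chosen for illustration, not part of this change. It assumes a 1-D index list and `depth >= 0`.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper, not part of TensorFlow or XLA: builds a one-hot
// matrix of shape [indices.size(), depth] by comparing every index against
// the linspace 0, 1, ..., depth - 1 and selecting on_value where they match.
std::vector<std::vector<float>> OneHotSketch(
    const std::vector<int64_t>& indices, int64_t depth, float on_value,
    float off_value) {
  // The linspace constant; in the diff this is built on the host and wrapped
  // in a literal before being handed to the builder.
  std::vector<int64_t> linspace(depth);
  for (int64_t i = 0; i < depth; ++i) linspace[i] = i;

  std::vector<std::vector<float>> result;
  result.reserve(indices.size());
  for (int64_t index : indices) {
    // "Broadcast" the index along the new axis and test equality per slot.
    std::vector<float> row(depth, off_value);
    for (int64_t i = 0; i < depth; ++i) {
      if (linspace[i] == index) row[i] = on_value;
    }
    result.push_back(std::move(row));
  }
  return result;
}
```

In this sketch an index outside `[0, depth)` matches no linspace entry, so its row is filled entirely with `off_value`.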
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 2b65f4d5d5936e062e5351a0723544191ffe2dfa..76c68d81af4dd9ec40fe6b1c33b03a876a0c6dc6 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -314,8 +314,8 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, } XlaContext& xla_context = XlaContext::Get(context_); - TensorShape representation_shape = xla_context.VariableRepresentationShape( - variable->shape(), variable->type()); + TensorShape representation_shape = + xla_context.RepresentationShape(variable->shape(), variable->type()); if (representation_shape == variable->shape()) { *value = variable->value(); } else { @@ -436,7 +436,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, XlaContext& xla_context = XlaContext::Get(context_); TensorShape representation_shape = - xla_context.VariableRepresentationShape(shape, type); + xla_context.RepresentationShape(shape, type); if (shape != representation_shape) { handle = builder()->Reshape(handle, representation_shape.dim_sizes()); } diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index e309cb1e34db7f8430c2494c03aed41652b7a167..4692038b61f6871a8a16299fd4d11e963eb46a57 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -39,10 +39,10 @@ const char* const DEVICE_XLA_GPU = "XLA_GPU"; static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) { const OpDef* op_def; - TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("_XlaLaunch", &op_def)); + TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("XlaLaunch", &op_def)); NodeDef node_def; node_def.set_name("_XlaLaunch-op"); - node_def.set_op("_XlaLaunch"); + node_def.set_op("XlaLaunch"); string kernel_class_name; TF_RETURN_IF_ERROR(FindKernelDef(device_type, node_def, /*KernelDef*/ 
nullptr, &kernel_class_name)); diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index dbf14f32bc3e54a9b4f0e1fbc5d827e8708b73f7..1b8e516770c3e217dd7c2f26ce426895b478c2e4 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -53,7 +53,6 @@ xla_proto_library( deps = [ ":xla_data_proto", "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/compiler/xla/service:session_proto", ], ) @@ -99,9 +98,9 @@ cc_library( hdrs = ["service_interface.h"], visibility = [":friends"], deps = [ + ":status", ":xla_data_proto", ":xla_proto", - "//tensorflow/core:lib", ], ) @@ -245,6 +244,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":protobuf_util", + ":status", ":status_macros", ":statusor", ":types", @@ -303,13 +303,13 @@ cc_library( ":array2d", ":array3d", ":array4d", - ":shape_tree", ":shape_util", ":sparse_index_array", ":status_macros", ":types", ":util", ":xla_data_proto", + "//tensorflow/core:framework", "//tensorflow/core:lib", ], ) @@ -324,12 +324,30 @@ tf_cc_test( ":shape_util", ":test", ":types", + "//tensorflow/compiler/tf2xla:common", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) +cc_library( + name = "error_spec", + hdrs = ["error_spec.h"], +) + +cc_library( + name = "literal_comparison", + srcs = ["literal_comparison.cc"], + hdrs = ["literal_comparison.h"], + deps = [ + ":error_spec", + ":literal_util", + ":util", + "//tensorflow/core:lib", + ], +) + cc_library( name = "metric_table_report", srcs = ["metric_table_report.cc"], @@ -564,6 +582,7 @@ tf_cc_test( ":shape_util", ":test", ":xla_data_proto", + "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index aac3273d5fd144f3b737529b0833c9328b3d0e4d..8f08d3b2e04670ad6590aca1db0fd9d25faed83f 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -63,7 +63,6 @@ 
cc_library( srcs = ["client.cc"], hdrs = ["client.h"], deps = [ - ":computation", ":global_data", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal_util", @@ -76,7 +75,7 @@ cc_library( "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", ], ) @@ -87,6 +86,7 @@ cc_library( hdrs = ["executable_build_options.h"], deps = [ "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", @@ -99,7 +99,6 @@ cc_library( hdrs = ["local_client.h"], deps = [ ":client", - ":computation", ":executable_build_options", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status_macros", @@ -111,6 +110,7 @@ cc_library( "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:source_map_util", @@ -126,7 +126,6 @@ cc_library( hdrs = ["compile_only_client.h"], deps = [ ":client", - ":computation", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -162,47 +161,6 @@ cc_library( ], ) -cc_library( - name = "computation", - srcs = ["computation.cc"], - hdrs = ["computation.h"], - deps = [ - "//tensorflow/compiler/xla:service_interface", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - 
"//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/service:session_proto", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "computation_builder", - srcs = ["computation_builder.cc"], - hdrs = ["computation_builder.h"], - deps = [ - ":client", - ":computation", - ":global_data", - ":padding", - "//tensorflow/compiler/xla:array", - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:array3d", - "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/core:lib", - ], -) - cc_library( name = "sharding_builder", srcs = ["sharding_builder.cc"], diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 328e1b8fa84e7baaca41c6c9a65e9a1598ac32ae..3d596a6e65430b6e9692aabd65fc8aa84b7b873d 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -64,7 +64,7 @@ StatusOr> Client::Transfer( } StatusOr> Client::TransferToServer( - const Literal& literal, const DeviceHandle* device_handle) { + const LiteralSlice& literal, const DeviceHandle* device_handle) { TransferToServerRequest request; *request.mutable_literal() = literal.ToProto(); if (device_handle) { @@ -91,7 +91,7 @@ StatusOr> Client::TransferToServer( return MakeUnique(stub_, response.data()); } -Status Client::TransferToInfeed(const Literal& literal, int64 replica_id, +Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id, const DeviceHandle* device_handle) { TransferToInfeedRequest request; *request.mutable_literal() = literal.ToProto(); @@ -161,22 +161,6 @@ Status Client::ResetDevice() { return Status::OK(); } -StatusOr> 
Client::ExecuteAndTransfer( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options, - ExecutionProfile* execution_profile) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr data, - Execute(computation, arguments, execution_options, execution_profile)); - - const Shape* shape_with_output_layout = nullptr; - if (execution_options && execution_options->has_shape_with_output_layout()) { - shape_with_output_layout = &execution_options->shape_with_output_layout(); - } - return Transfer(*data, shape_with_output_layout); -} - StatusOr> Client::ExecuteAndTransfer( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -221,65 +205,11 @@ StatusOr> Client::ComputeConstant( return Literal::CreateFromProto(response.literal()); } -StatusOr Client::LoadSnapshot(const SessionModule& module) { - LoadComputationSnapshotRequest request; - *request.mutable_module() = module; - LoadComputationSnapshotResponse response; - - Status s = stub_->LoadComputationSnapshot(&request, &response); - if (!s.ok()) { - return s; - } - - VLOG(1) << "load snapshot response: " << response.ShortDebugString(); - return Computation(stub_, response.computation()); -} - StatusOr Client::LoadSnapshot(const HloSnapshot& module) { TF_RET_CHECK(module.has_hlo() && module.hlo().has_hlo_module()); return XlaComputation(module.hlo().hlo_module()); } -StatusOr> Client::Execute( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options, - ExecutionProfile* execution_profile) { - ExecuteRequest request; - *request.mutable_computation() = computation.handle(); - - if (execution_options == nullptr) { - *request.mutable_execution_options() = CreateDefaultExecutionOptions(); - } else { - *request.mutable_execution_options() = *execution_options; - } - for (GlobalData* argument : arguments) { - CHECK(argument != nullptr) << "Argument pointers must not be null."; - 
*request.add_arguments() = argument->handle(); - } - - ExecuteResponse response; - VLOG(1) << "making execute request: " << request.ShortDebugString(); - Status s = stub_->Execute(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - if (execution_profile != nullptr) { - *execution_profile = response.profile(); - if (VLOG_IS_ON(1)) { - TF_ASSIGN_OR_RETURN( - auto execution_stats, - ExecutionStatsAsString(computation, response.profile())); - VLOG(1) << execution_stats; - } - } - - return MakeUnique(stub_, response.output()); -} - StatusOr> Client::Execute( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -320,41 +250,6 @@ StatusOr> Client::Execute( return MakeUnique(stub_, response.output()); } -StatusOr>> Client::ExecuteParallel( - tensorflow::gtl::ArraySlice computations) { - ExecuteParallelRequest request; - - for (const ComputationInstance& computation : computations) { - ExecuteRequest single_request; - *single_request.mutable_computation() = computation.computation.handle(); - for (GlobalData* argument : computation.arguments) { - *single_request.add_arguments() = argument->handle(); - } - *single_request.mutable_execution_options() = computation.execution_options; - *request.add_requests() = single_request; - } - - ExecuteParallelResponse response; - VLOG(1) << "making execute-parallel request: " << request.ShortDebugString(); - tensorflow::Status s = stub_->ExecuteParallel(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - std::vector> outputs; - for (size_t i = 0; i < computations.size(); ++i) { - outputs.push_back( - MakeUnique(stub_, response.responses(i).output())); - if (computations[i].execution_profile != nullptr) { - *computations[i].execution_profile = response.responses(i).profile(); - } - } - - return std::move(outputs); -} - StatusOr>> Client::ExecuteParallel( tensorflow::gtl::ArraySlice computations) { ExecuteGraphParallelRequest 
request; @@ -372,7 +267,7 @@ StatusOr>> Client::ExecuteParallel( ExecuteParallelResponse response; VLOG(1) << "making execute-graph-parallel request: " << request.ShortDebugString(); - tensorflow::Status s = stub_->ExecuteGraphParallel(&request, &response); + Status s = stub_->ExecuteGraphParallel(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { @@ -401,7 +296,7 @@ StatusOr> Client::GetDeviceHandles( GetDeviceHandlesResponse response; VLOG(1) << "making get device request: " << request.ShortDebugString(); - tensorflow::Status s = stub_->GetDeviceHandles(&request, &response); + Status s = stub_->GetDeviceHandles(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { @@ -449,24 +344,6 @@ StatusOr>> Client::DeconstructTuple( return std::move(handles); } -StatusOr Client::GetComputationStats( - const Computation& computation, const DebugOptions& debug_options) const { - ComputationStatsRequest request; - *request.mutable_computation() = computation.handle(); - *request.mutable_debug_options() = debug_options; - ComputationStatsResponse response; - - VLOG(1) << "making computation stats request"; - Status s = stub_->GetComputationStats(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - CHECK(response.has_stats()); - return response.stats(); -} - StatusOr Client::GetComputationStats( const XlaComputation& computation, const DebugOptions& debug_options) const { @@ -488,23 +365,6 @@ StatusOr Client::GetComputationStats( return response.stats(); } -StatusOr> Client::GetComputationShape( - const Computation& computation) { - GetComputationShapeRequest request; - *request.mutable_computation() = computation.handle(); - GetComputationShapeResponse response; - - VLOG(1) << "making get-computation-shape request"; - Status s = stub_->GetComputationShape(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - return WrapUnique(response.release_program_shape()); -} - 
StatusOr> Client::GetComputationShape( const XlaComputation& computation) { TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape()); @@ -527,28 +387,6 @@ StatusOr Client::GetShape(const GlobalData& data) { return response.shape(); } -StatusOr Client::ExecutionStatsAsString( - const Computation& computation, const ExecutionProfile& profile) { - TF_ASSIGN_OR_RETURN( - auto computation_stats, - GetComputationStats(computation, - legacy_flags::GetDebugOptionsFromFlags())); - int64 total_flops = - computation_stats.flop_count() + computation_stats.transcendental_count(); - if (profile.compute_time_ns() > 0) { - int64 nanoseconds = profile.compute_time_ns(); - int64 cycle_count = profile.compute_cycle_count(); - double gflops = total_flops / nanoseconds; - return tensorflow::strings::StrCat( - "[Execution Statistics] flop count: ", computation_stats.flop_count(), - ", transcendental count: ", computation_stats.transcendental_count(), - ", compute execution time: ", nanoseconds, " nsec", - ", compute cycles: ", cycle_count, ", performance: ", gflops, - "gflop/s"); - } - return string("[Execution Statistics] not available."); -} - StatusOr Client::ExecutionStatsAsString( const XlaComputation& computation, const ExecutionProfile& profile) { TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index a63ff4c56d1dd78c7abfa2bf163b5fbd54d82b2b..68f0d0ac78c859fde7a6a007cd250b047a7bfcda 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -19,11 +19,10 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -52,21 +51,6 @@ class Client { // device is chosen by the service. // * If execution_profile is not nullptr then the pointed-to ExecutionProfile // will be filled with profile data from the execution. - StatusOr> Execute( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options = nullptr, - ExecutionProfile* execution_profile = nullptr); - - // Executes the computation with the given arguments and returns the global - // data that was produced from the execution. - // * If execution_options is not nullptr, these options are passed to the - // service to affect how it compiles our computation. (The pointer does not - // need to live beyond this call.) - // * If execution_profile is not nullptr then the pointed-to ExecutionProfile - // will be filled with profile data from the execution. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> Execute( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -78,34 +62,6 @@ class Client { // executed on the devices associated with the handles by partitioning the // computation based on the attached sharding attributes. Otherwise, a // device is chosen by the service. 
- struct ComputationInstance { - const Computation& computation; - std::vector arguments; - ExecutionOptions execution_options; - ExecutionProfile* execution_profile; - - ComputationInstance(const Computation& computation, - std::vector arguments, - ExecutionOptions execution_options, - ExecutionProfile* execution_profile) - : computation(computation), - arguments(std::move(arguments)), - execution_options(execution_options), - execution_profile(execution_profile) {} - }; - - // Executes a list ComputationInstances and returns global data produced from - // each computation. - StatusOr>> ExecuteParallel( - tensorflow::gtl::ArraySlice computations); - - // A struct to represent a computation instance to be executed. - // * If execution_options.device_handles is not empty, the computation is - // executed on the devices associated with the handles by partitioning the - // computation based on the attached sharding attributes. Otherwise, a - // device is chosen by the service. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. struct XlaComputationInstance { const XlaComputation& computation; std::vector arguments; @@ -125,7 +81,6 @@ class Client { // Executes a list XlaComputationInstances and returns global data produced // from each computation. // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr>> ExecuteParallel( tensorflow::gtl::ArraySlice computations); @@ -152,14 +107,14 @@ class Client { // device (and its replicas if replication is enabled). Otherwise, data is // transferred to the default device (and its replicas). StatusOr> TransferToServer( - const Literal& literal, const DeviceHandle* device_handle = nullptr); + const LiteralSlice& literal, const DeviceHandle* device_handle = nullptr); // Transfer the given literal to the Infeed interface of the device. 
// // device_handle and replica_id together specify a particular device; a device // assigned for the given replica_id among the replicas that the given device // handle belongs to. - Status TransferToInfeed(const Literal& literal, int64 replica_id = 0, + Status TransferToInfeed(const LiteralSlice& literal, int64 replica_id = 0, const DeviceHandle* device_handle = nullptr); // Transfers from the Outfeed of the device. @@ -177,17 +132,6 @@ class Client { // Executes the computation with the given arguments and transfers the result // to the client as a literal. Parameters are defined the same as for // Execute() and Transfer(). - StatusOr> ExecuteAndTransfer( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options = nullptr, - ExecutionProfile* execution_profile = nullptr); - - // Executes the computation with the given arguments and transfers the result - // to the client as a literal. Parameters are defined the same as for - // Execute() and Transfer(). - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> ExecuteAndTransfer( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -209,8 +153,6 @@ class Client { // // If output_layout is non-null, then the output of the computation will be // stored using that layout. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> ComputeConstant( const XlaComputation& computation, const Layout* output_layout = nullptr) const; @@ -223,12 +165,6 @@ class Client { const GlobalData& data); // Retrieves the statistics of the given computation. - StatusOr GetComputationStats( - const Computation& computation, const DebugOptions& debug_options) const; - - // Retrieves the statistics of the given computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. 
StatusOr GetComputationStats( const XlaComputation& computation, const DebugOptions& debug_options) const; @@ -239,13 +175,6 @@ class Client { // As above, but returns the shape of the provided computation (parameter // types/names and return type). - StatusOr> GetComputationShape( - const Computation& computation); - - // As above, but returns the shape of the provided computation (parameter - // types/names and return type). - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> GetComputationShape( const XlaComputation& computation); @@ -253,9 +182,6 @@ class Client { // two computations via a pair of Send and Recv instructions. StatusOr CreateChannelHandle(); - StatusOr LoadSnapshot(const SessionModule& module); - - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr LoadSnapshot(const HloSnapshot& module); ServiceInterface* stub() { return stub_; } @@ -263,8 +189,6 @@ class Client { private: // Returns the execution statistics (e.g., gflop/s) as a string from the // ExecutionProfile returned from an execution of the computation. - StatusOr ExecutionStatsAsString(const Computation& computation, - const ExecutionProfile& profile); StatusOr ExecutionStatsAsString(const XlaComputation& computation, const ExecutionProfile& profile); diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index 96e38bca01087991943aff40ed1cb3e21f9e6cba..dc69d2097ebe14ca0e14a39849d4fcae99024fdc 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -21,24 +21,6 @@ limitations under the License. 
namespace xla { -StatusOr>> -CompileOnlyClient::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { - std::vector service_instances; - service_instances.reserve(computations.size()); - for (const AotComputationInstance& instance : computations) { - service_instances.push_back({}); - CompileOnlyService::AotComputationInstance& service_instance = - service_instances.back(); - TF_RET_CHECK(instance.computation != nullptr); - service_instance.computation = instance.computation->handle(); - service_instance.argument_layouts = instance.argument_layouts; - service_instance.result_layout = instance.result_layout; - } - return compiler_service_->CompileAheadOfTime(service_instances, options); -} - StatusOr>> CompileOnlyClient::CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h index c8725b8517484acdaf093bc3b34adb00f69155b1..f9a7c31270c7a11175f47a537639a97d0c9211af 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.h +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -17,7 +17,6 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/compile_only_service.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -38,26 +37,7 @@ class CompileOnlyClient : public Client { CompileOnlyClient(const CompileOnlyClient&) = delete; void operator=(const CompileOnlyClient&) = delete; - // A description of a computation to compile using CompileAheadOfTime. - struct AotComputationInstance { - const Computation* computation; - // Inform the compiler of the expected layout for arguments. 
- std::vector argument_layouts; - // Specifies the expected result layout. - const Shape* result_layout; - }; - - // Compiles a list of computations for ahead-of-time execution. This is - // intended for use in static compilation. The |options| parameter describes - // the target for which the compiler should emit code. - StatusOr>> - CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options); - // A description of an xla computation to compile using CompileAheadOfTime. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. struct AotXlaComputationInstance { const XlaComputation* computation; // Inform the compiler of the expected layout for arguments. @@ -69,8 +49,6 @@ class CompileOnlyClient : public Client { // Compiles a list of xla computations for ahead-of-time execution. This is // intended for use in static compilation. The |options| parameter describes // the target for which the compiler should emit code. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr>> CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, diff --git a/tensorflow/compiler/xla/client/computation.cc b/tensorflow/compiler/xla/client/computation.cc deleted file mode 100644 index e6c57bda0f0c4cb969939883efebcf3a6d6be381..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/computation.cc +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/client/computation.h" - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace xla { - -Computation::Computation() : parent_(nullptr) {} - -Computation::Computation(ServiceInterface* parent, - const ComputationHandle& handle) - : handle_(handle), parent_(parent) {} - -Computation::Computation(Computation&& computation) - : handle_(std::move(computation.handle_)), parent_(computation.parent_) { - computation.ResetWithoutFreeing(); -} - -void Computation::Reset() { - // TODO(b/34469253) deallocate any owned computation. - ResetWithoutFreeing(); -} - -StatusOr> Computation::Snapshot() const { - SnapshotComputationRequest request; - *request.mutable_computation() = handle_; - SnapshotComputationResponse response; - - TF_RETURN_IF_ERROR(parent_->SnapshotComputation(&request, &response)); - - return WrapUnique(response.release_module()); -} - -Computation::~Computation() { Reset(); } - -Computation& Computation::operator=(Computation&& computation) { - if (&computation != this) { - Reset(); - handle_ = computation.handle_; - parent_ = computation.parent_; - computation.ResetWithoutFreeing(); - } - return *this; -} - -void Computation::ResetWithoutFreeing() { - handle_.Clear(); - parent_ = nullptr; -} - -StatusOr Computation::GetProgramShape() const { - GetComputationShapeRequest request; - *request.mutable_computation() = handle_; - GetComputationShapeResponse response; - - TF_RETURN_IF_ERROR(parent_->GetComputationShape(&request, &response)); - - return std::move(*response.mutable_program_shape()); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/client/computation.h b/tensorflow/compiler/xla/client/computation.h deleted file mode 100644 
index 9a1bcde76387297cb7f374b25baad1d5ec284859..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/computation.h +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ - -#include - -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service_interface.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/macros.h" - -namespace xla { - -// Wraps a ComputationHandle protobuf with a lifetime. Computation is -// movable and not copyable to capture the same kind of unique -// ownership that std::unique_ptr represents. -// -// TODO(b/74197823): Deprecated. Use XlaComputation instead. -class Computation { - public: - // Creates a null Computation. - Computation(); - - // parent: stub for the service on which we will deallocate the computation - // when it is no longer needed. - // handle: the computation handle protobuf from the service. - Computation(ServiceInterface* parent, const ComputationHandle& handle); - - Computation(Computation&& computation); - - // Deallocates the computation. 
- ~Computation(); - - Computation& operator=(Computation&& computation); - - // Returns the underlying handle. - const ComputationHandle& handle() const { return handle_; } - - // Sets handle to a null state and clears any owned computation. - void Reset(); - - // Requests that we snapshot the computation into a serializable protocol - // buffer form. - StatusOr<std::unique_ptr<SessionModule>> Snapshot() const; - - // Returns true if this object is a null Computation. - bool IsNull() const { return parent_ == nullptr; } - - // Returns the "program shape" (parameter and return shapes) for this - // computation. - StatusOr<ProgramShape> GetProgramShape() const; - - private: - void ResetWithoutFreeing(); - - ComputationHandle handle_; // Handle that is wrapped by this class. - - // Stub that the handle is deallocated on when this object's lifetime ends. - ServiceInterface* parent_; - - TF_DISALLOW_COPY_AND_ASSIGN(Computation); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc deleted file mode 100644 index 83c7cb174402133706fbde6a734a29afd8edfe80..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ /dev/null @@ -1,1574 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/client/computation_builder.h" - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/protobuf.h" - -namespace xla { - -ComputationBuilder::ComputationBuilder(Client* client, - const string& computation_name) - : name_(computation_name), client_(client) {} - -ComputationBuilder::~ComputationBuilder() {} - -void ComputationBuilder::NoteError(const Status& error) { - if (die_immediately_on_error_) { - LOG(FATAL) << "error building computation: " << error; - } - - if (first_error_.ok()) { - first_error_ = error; - first_error_backtrace_.CreateCurrent(/*skip_count=*/1); - } -} - -std::unique_ptr ComputationBuilder::CreateSubBuilder( - const string& computation_name) { - auto sub_builder = MakeUnique(client_, computation_name); - sub_builder->parent_builder_ = this; - sub_builder->die_immediately_on_error_ = die_immediately_on_error_; - return sub_builder; -} - -Status ComputationBuilder::PrepareComputation() { - TF_RETURN_IF_ERROR(first_error_); - - if (!computation_.IsNull()) { - return Status::OK(); - } - - ComputationRequest request; - request.set_name(name_); - ComputationResponse response; - - VLOG(2) << "making computation request"; - Status s = client_->stub()->Computation(&request, &response); - VLOG(2) << "done with computation request"; - - if (!s.ok()) { - NoteError(s); - return first_error_; - } - - computation_ = Computation(client_->stub(), response.computation()); - return Status::OK(); -} 
- -Status ComputationBuilder::RunOp(OpRequest* op_request, - OpResponse* op_response) { - TF_RETURN_IF_ERROR(first_error_); - TF_RETURN_IF_ERROR(PrepareComputation()); - - // Fill in fields that are set on every OpRequest. - *op_request->mutable_computation() = computation_.handle(); - *op_request->mutable_metadata() = metadata_; - if (sharding_) { - *op_request->mutable_sharding() = *sharding_; - } - - const string& op_name = - OpRequest::descriptor()->FindFieldByNumber(op_request->op_case())->name(); - VLOG(2) << "running op request: " << op_name; - Status status = client_->stub()->Op(op_request, op_response); - VLOG(2) << "done with op request: " << op_name; - return status; -} - -void ComputationBuilder::RunOpAndNoteError(OpRequest* op_request) { - OpResponse op_response; - Status status = RunOp(op_request, &op_response); - if (!status.ok()) { - NoteError(status); - } -} - -ComputationDataHandle ComputationBuilder::RunOpAndParseResponse( - OpRequest* op_request) { - OpResponse op_response; - Status status = RunOp(op_request, &op_response); - if (!status.ok()) { - NoteError(status); - return ComputationDataHandle(); - } - if (op_response.output().handle() == 0) { - NoteError(InternalError("No output handle")); - return ComputationDataHandle(); - } - return op_response.output(); -} - -bool ComputationBuilder::MakeWindow( - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, Window* window) { - const auto verify_size = [&](const size_t x, const char* x_name) { - if (x == 0 || x == window_dimensions.size()) { - return true; - } else { - NoteError(InvalidArgument( - "%s", tensorflow::strings::StrCat( - "Window has different number of window dimensions than of ", - x_name, "\nNumber of window dimensions: ", - window_dimensions.size(), "\nNumber of ", x_name, ": ", x, - "\n") - .c_str())); // - return 
false; - } - }; - if (!verify_size(window_strides.size(), "window strides") || - !verify_size(padding.size(), "padding entries") || - !verify_size(lhs_dilation.size(), "lhs dilation factors") || - !verify_size(rhs_dilation.size(), "rhs dilation factors")) { - return false; - } - - window->Clear(); - for (size_t i = 0; i < window_dimensions.size(); i++) { - auto dim = window->add_dimensions(); - dim->set_size(window_dimensions[i]); - if (!window_strides.empty()) { - dim->set_stride(window_strides[i]); - } else { - dim->set_stride(1); - } - if (!padding.empty()) { - dim->set_padding_low(padding[i].first); - dim->set_padding_high(padding[i].second); - } else { - dim->set_padding_low(0); - dim->set_padding_high(0); - } - if (!lhs_dilation.empty()) { - dim->set_base_dilation(lhs_dilation[i]); - } else { - dim->set_base_dilation(1); - } - if (!rhs_dilation.empty()) { - dim->set_window_dilation(rhs_dilation[i]); - } else { - dim->set_window_dilation(1); - } - dim->set_window_reversal(false); - } - return true; -} - -ComputationDataHandle ComputationBuilder::ConstantLiteral( - const Literal& literal) { - OpRequest op_request; - ConstantRequest* request = op_request.mutable_constant_request(); - *request->mutable_literal() = literal.ToProto(); - VLOG(3) << "created constant: " << request->literal().ShortDebugString(); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Parameter(int64 parameter_number, - const Shape& shape, - const string& name) { - OpRequest op_request; - ParameterRequest* request = op_request.mutable_parameter_request(); - *request->mutable_shape() = shape; - request->set_parameter(parameter_number); - request->set_name(name); - return RunOpAndParseResponse(&op_request); -} - -StatusOr> ComputationBuilder::GetShapeWithoutNoteError( - const ComputationDataHandle& operand) { - GetLocalShapeRequest request; - *request.mutable_computation() = computation_.handle(); - *request.mutable_operand() = operand; - 
GetLocalShapeResponse response; - - VLOG(2) << "making get-shape request"; - TF_RETURN_IF_ERROR(client_->stub()->GetLocalShape(&request, &response)); - VLOG(2) << "done with request"; - - TF_RET_CHECK(response.has_shape()); - std::unique_ptr shape = WrapUnique(response.release_shape()); - TF_RET_CHECK(shape != nullptr); - return std::move(shape); -} - -StatusOr> ComputationBuilder::GetShape( - const ComputationDataHandle& operand) { - TF_RETURN_IF_ERROR(first_error_); - - auto status_or_shape = GetShapeWithoutNoteError(operand); - if (!status_or_shape.ok()) { - NoteError(status_or_shape.status()); - return first_error_; - } - return status_or_shape; -} - -StatusOr ComputationBuilder::GetProgramShape() { - TF_RETURN_IF_ERROR(first_error_); - - GetComputationShapeRequest request; - *request.mutable_computation() = computation_.handle(); - GetComputationShapeResponse response; - - VLOG(2) << "making get-program-shape-request"; - Status status = client_->stub()->GetComputationShape(&request, &response); - VLOG(2) << "done with get-program-shape-request"; - - if (!status.ok()) { - first_error_ = status; - return status; - } - - TF_RET_CHECK(response.has_program_shape()); - return std::move(*response.mutable_program_shape()); -} - -ComputationDataHandle ComputationBuilder::Slice( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { - OpRequest op_request; - SliceRequest* request = op_request.mutable_slice_request(); - *request->mutable_operand() = operand; - for (int64 index : start_indices) { - request->add_start_indices(index); - } - for (int64 index : limit_indices) { - request->add_limit_indices(index); - } - for (int64 index : strides) { - request->add_strides(index); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::SliceInDim( - const ComputationDataHandle& operand, int64 start_index, int64 
limit_index, - int64 stride, int64 dimno) { - StatusOr> shape_status = GetShape(operand); - if (!shape_status.ok()) { - NoteError(shape_status.status()); - return ComputationDataHandle{}; - } - const Shape& shape = *shape_status.ValueOrDie(); - std::vector starts(ShapeUtil::Rank(shape), 0); - std::vector limits(shape.dimensions().begin(), - shape.dimensions().end()); - std::vector strides(ShapeUtil::Rank(shape), 1); - starts[dimno] = start_index; - limits[dimno] = limit_index; - strides[dimno] = stride; - return Slice(operand, starts, limits, strides); -} - -ComputationDataHandle ComputationBuilder::DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, - tensorflow::gtl::ArraySlice slice_sizes) { - OpRequest op_request; - DynamicSliceRequest* request = op_request.mutable_dynamic_slice_request(); - *request->mutable_operand() = operand; - *request->mutable_start_indices() = start_indices; - for (int64 index : slice_sizes) { - request->add_slice_sizes(index); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices) { - OpRequest op_request; - DynamicUpdateSliceRequest* request = - op_request.mutable_dynamic_update_slice_request(); - *request->mutable_operand() = operand; - *request->mutable_update() = update; - *request->mutable_start_indices() = start_indices; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension) { - OpRequest op_request; - ConcatenateRequest* request = op_request.mutable_concatenate_request(); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - request->set_dimension(dimension); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle 
ComputationBuilder::Broadcast( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice broadcast_sizes) { - OpRequest op_request; - BroadcastRequest* request = op_request.mutable_broadcast_request(); - *request->mutable_operand() = operand; - for (int64 size : broadcast_sizes) { - request->add_broadcast_sizes(size); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Pad( - const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config) { - OpRequest op_request; - PadRequest* request = op_request.mutable_pad_request(); - *request->mutable_operand() = operand; - *request->mutable_padding_value() = padding_value; - *request->mutable_padding_config() = padding_config; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Reshape( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { - OpRequest op_request; - ReshapeRequest* request = op_request.mutable_reshape_request(); - *request->mutable_operand() = operand; - for (int64 dimension : dimensions) { - request->add_dimensions(dimension); - } - for (int64 new_size : new_sizes) { - request->add_new_sizes(new_size); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Reshape( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice new_sizes) { - if (!first_error_.ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - return ComputationDataHandle(); - } - std::vector dimensions(shape.ValueOrDie()->dimensions().size()); - std::iota(dimensions.begin(), dimensions.end(), 0); - return Reshape(operand, dimensions, new_sizes); -} - -ComputationDataHandle ComputationBuilder::Collapse( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - if (!first_error_.ok()) { 
- return ComputationDataHandle(); - } - - // Don't support out-of-order collapse here. - // Checks that the collapsed dimensions are in order and consecutive. - for (tensorflow::gtl::ArraySlice::size_type i = 1; - i < dimensions.size(); ++i) { - if (dimensions[i] - 1 != dimensions[i - 1]) { - NoteError(InvalidArgument( - "Collapsed dimensions are not in order and consecutive.")); - return ComputationDataHandle(); - } - } - - // Create a new sizes vector from the old shape, replacing the collapsed - // dimensions by the product of their sizes. - StatusOr> shape_or_status = GetShape(operand); - if (!shape_or_status.ok()) { - return ComputationDataHandle(); - } - std::unique_ptr original_shape = shape_or_status.ConsumeValueOrDie(); - - VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape); - VLOG(3) << "dims to collapse: " - << tensorflow::str_util::Join(dimensions, ","); - - if (dimensions.size() <= 1) { - // Not collapsing anything, trivially we can return the operand versus - // enqueueing a trivial reshape. 
- return operand; - } - - std::vector new_sizes; - for (int i = 0; i < ShapeUtil::Rank(*original_shape); ++i) { - if (i <= dimensions.front() || i > dimensions.back()) { - new_sizes.push_back(original_shape->dimensions(i)); - } else { - new_sizes.back() *= original_shape->dimensions(i); - } - } - - VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",") - << "]"; - - return Reshape(operand, new_sizes); -} - -void ComputationBuilder::Trace(const string& tag, - const ComputationDataHandle& operand) { - OpRequest op_request; - TraceRequest* request = op_request.mutable_trace_request(); - request->set_tag(tag); - *request->mutable_operand() = operand; - RunOpAndNoteError(&op_request); -} - -ComputationDataHandle ComputationBuilder::Select( - const ComputationDataHandle& pred, const ComputationDataHandle& on_true, - const ComputationDataHandle& on_false) { - return TernaryOp(TRIOP_SELECT, pred, on_true, on_false); -} - -ComputationDataHandle ComputationBuilder::Tuple( - tensorflow::gtl::ArraySlice elements) { - OpRequest op_request; - VariadicOpRequest* request = op_request.mutable_variadic_op_request(); - request->set_varop(VAROP_TUPLE); - for (const ComputationDataHandle& operand : elements) { - *request->add_operands() = operand; - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::GetTupleElement( - const ComputationDataHandle& tuple_data, int64 index) { - OpRequest op_request; - GetTupleElementRequest* request = - op_request.mutable_get_tuple_element_request(); - *request->mutable_operand() = tuple_data; - request->set_index(index); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Eq( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_EQ, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Ne( - const ComputationDataHandle& lhs, const 
ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_NE, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Ge( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_GE, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Gt( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_GT, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Le( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_LE, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Lt( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_LT, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Dot( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { - StatusOr> lhs_shape_or_status = GetShape(lhs); - if (!lhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); - - DotDimensionNumbers dimension_numbers; - dimension_numbers.add_lhs_contracting_dimensions( - lhs_shape->dimensions_size() == 1 ? 
0 : 1); - dimension_numbers.add_rhs_contracting_dimensions(0); - return DotGeneral(lhs, rhs, dimension_numbers); -} - -ComputationDataHandle ComputationBuilder::DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - const DotDimensionNumbers& dimension_numbers) { - OpRequest op_request; - DotRequest* request = op_request.mutable_dot_request(); - *request->mutable_lhs() = lhs; - *request->mutable_rhs() = rhs; - *request->mutable_dimension_numbers() = dimension_numbers; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Conv( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding) { - return ConvWithGeneralDimensions( - lhs, rhs, window_strides, padding, - CreateDefaultConvDimensionNumbers(window_strides.size())); -} - -ComputationDataHandle ComputationBuilder::ConvWithGeneralPadding( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { - return ConvGeneral(lhs, rhs, window_strides, padding, - CreateDefaultConvDimensionNumbers(window_strides.size())); -} - -bool ComputationBuilder::VerifyConvolution( - const Shape& lhs_shape, const Shape& rhs_shape, - const ConvolutionDimensionNumbers& dimension_numbers) { - if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) { - NoteError( - InvalidArgument("Convolution arguments must have same number of " - "dimensions. Got: %s and %s", - ShapeUtil::HumanString(lhs_shape).c_str(), - ShapeUtil::HumanString(rhs_shape).c_str())); - return false; - } - int num_dims = ShapeUtil::Rank(lhs_shape); - if (num_dims < 2) { - NoteError(InvalidArgument( - "Convolution expects argument arrays with >= 3 dimensions. 
" - "Got: %s and %s", - ShapeUtil::HumanString(lhs_shape).c_str(), - ShapeUtil::HumanString(rhs_shape).c_str())); - return false; - } - int num_spatial_dims = num_dims - 2; - - const auto check_spatial_dimensions = - [&](const char* const field_name, - const tensorflow::protobuf::RepeatedField& - numbers) { - if (numbers.size() != num_spatial_dims) { - NoteError(InvalidArgument("Expected %d elements for %s, but got %d.", - num_spatial_dims, field_name, - numbers.size())); - return false; - } - for (int i = 0; i < numbers.size(); ++i) { - if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) { - NoteError( - InvalidArgument("Convolution %s[%d] is out of bounds: %lld", - field_name, i, numbers.Get(i))); - return false; - } - } - return true; - }; - return check_spatial_dimensions( - "input_spatial_dimensions", - dimension_numbers.input_spatial_dimensions()) && - check_spatial_dimensions( - "kernel_spatial_dimensions", - dimension_numbers.kernel_spatial_dimensions()) && - check_spatial_dimensions( - "output_spatial_dimensions", - dimension_numbers.output_spatial_dimensions()); -} - -ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> lhs_shape_or_status = GetShape(lhs); - if (!lhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - - StatusOr> rhs_shape_or_status = GetShape(rhs); - if (!rhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - - std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); - std::unique_ptr rhs_shape = rhs_shape_or_status.ConsumeValueOrDie(); - - if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) { - NoteError(InternalError("failed to verify convolution")); - return 
ComputationDataHandle(); - } - - std::vector base_area_dimensions( - dimension_numbers.input_spatial_dimensions_size()); - for (std::vector::size_type i = 0; i < base_area_dimensions.size(); - ++i) { - base_area_dimensions[i] = - lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i)); - } - - std::vector window_dimensions( - dimension_numbers.kernel_spatial_dimensions_size()); - for (std::vector::size_type i = 0; i < window_dimensions.size(); ++i) { - window_dimensions[i] = - rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); - } - - return ConvGeneral(lhs, rhs, window_strides, - MakePadding(base_area_dimensions, window_dimensions, - window_strides, padding), - dimension_numbers); -} - -ComputationDataHandle ComputationBuilder::ConvGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers) { - return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, - dimension_numbers); -} - -ComputationDataHandle ComputationBuilder::ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> lhs_shape_or_status = GetShape(lhs); - if (!lhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - - StatusOr> rhs_shape_or_status = GetShape(rhs); - if (!rhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - - std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); - std::unique_ptr rhs_shape = rhs_shape_or_status.ConsumeValueOrDie(); - if (!VerifyConvolution(*lhs_shape, *rhs_shape, 
dimension_numbers)) { - // Error is recorded in VerifyConvolution. - return ComputationDataHandle(); - } - - std::vector window_dimensions( - dimension_numbers.kernel_spatial_dimensions_size()); - for (std::vector::size_type i = 0; i < window_dimensions.size(); ++i) { - window_dimensions[i] = - rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); - } - - OpRequest op_request; - ConvolveRequest* request = op_request.mutable_convolve_request(); - *request->mutable_lhs() = lhs; - *request->mutable_rhs() = rhs; - *request->mutable_dimension_numbers() = dimension_numbers; - - if (!MakeWindow(window_dimensions, window_strides, padding, lhs_dilation, - rhs_dilation, request->mutable_window())) { - // Error is recorded in MakeWindow. - return ComputationDataHandle(); - } - - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Fft( - const ComputationDataHandle& operand, const FftType fft_type, - const tensorflow::gtl::ArraySlice fft_length) { - OpRequest op_request; - FftRequest* request = op_request.mutable_fft_request(); - *request->mutable_operand() = operand; - request->set_fft_type(fft_type); - for (int64 dim_len : fft_length) { - request->add_fft_length(dim_len); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape, - const string& config) { - OpRequest op_request; - InfeedRequest* request = op_request.mutable_infeed_request(); - *request->mutable_shape() = shape; - *request->mutable_config() = config; - return RunOpAndParseResponse(&op_request); -} - -void ComputationBuilder::Outfeed(const ComputationDataHandle& operand, - const Shape& shape_with_layout, - const string& outfeed_config) { - OpRequest op_request; - OutfeedRequest* request = op_request.mutable_outfeed_request(); - request->set_outfeed_config(outfeed_config); - *request->mutable_operand() = operand; - *request->mutable_shape() = shape_with_layout; - 
RunOpAndNoteError(&op_request); -} - -ComputationDataHandle ComputationBuilder::Call( - const Computation& computation, - tensorflow::gtl::ArraySlice operands) { - OpRequest op_request; - CallRequest* request = op_request.mutable_call_request(); - *request->mutable_to_apply() = computation.handle(); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::CustomCall( - const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape) { - OpRequest op_request; - CustomCallRequest* request = op_request.mutable_custom_call_request(); - request->set_call_target_name(call_target_name); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - *request->mutable_shape() = shape; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::HostCompute( - tensorflow::gtl::ArraySlice operands, - const string& channel_name, int64 cost_estimate_ns, const Shape& shape) { - OpRequest op_request; - HostComputeRequest* request = op_request.mutable_host_compute_request(); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - *request->mutable_shape() = shape; - request->set_channel_name(channel_name); - request->set_cost_estimate_ns(cost_estimate_ns); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Complex( - const ComputationDataHandle& real, const ComputationDataHandle& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_COMPLEX, real, imag, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Conj( - const ComputationDataHandle& operand) { - return Complex(Real(operand), Neg(Imag(operand))); -} - -ComputationDataHandle ComputationBuilder::Add( - const ComputationDataHandle& lhs, const 
ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_ADD, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Sub( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_SUB, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Mul( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_MUL, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Div( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_DIV, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Rem( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_REM, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Max( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_MAX, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Min( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_MIN, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::And( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_AND, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Or( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - 
return BinaryOp(BINOP_OR, lhs, rhs, broadcast_dimensions); -} - -// TODO(b/65209188): Create a dedicated lowering for Xor -ComputationDataHandle ComputationBuilder::Xor( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return Or(And(Not(lhs), rhs, broadcast_dimensions), - And(lhs, Not(rhs), broadcast_dimensions)); -} - -ComputationDataHandle ComputationBuilder::Not( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_NOT, operand); -} - -ComputationDataHandle ComputationBuilder::ShiftLeft( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_SHIFT_LEFT, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::ShiftRightArithmetic( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_SHIFT_RIGHT_ARITHMETIC, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::ShiftRightLogical( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_SHIFT_RIGHT_LOGICAL, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Abs( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_ABS, operand); -} - -ComputationDataHandle ComputationBuilder::Atan2( - const ComputationDataHandle& y, const ComputationDataHandle& x, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_ATAN2, y, x, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Exp( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_EXP, operand); -} - -ComputationDataHandle ComputationBuilder::Floor( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_FLOOR, operand); -} - -ComputationDataHandle 
ComputationBuilder::Ceil( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_CEIL, operand); -} - -ComputationDataHandle ComputationBuilder::Round( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_ROUND_NEAREST_AFZ, operand); -} - -ComputationDataHandle ComputationBuilder::Log( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_LOG, operand); -} - -ComputationDataHandle ComputationBuilder::Sign( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_SIGN, operand); -} - -ComputationDataHandle ComputationBuilder::Cos( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_COS, operand); -} - -ComputationDataHandle ComputationBuilder::Sin( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_SIN, operand); -} - -ComputationDataHandle ComputationBuilder::Tanh( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_TANH, operand); -} - -ComputationDataHandle ComputationBuilder::Real( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_REAL, operand); -} - -ComputationDataHandle ComputationBuilder::Imag( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_IMAG, operand); -} - -ComputationDataHandle ComputationBuilder::IsFinite( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_IS_FINITE, operand); -} - -ComputationDataHandle ComputationBuilder::Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation) { - OpRequest op_request; - TransposeRequest* request = op_request.mutable_transpose_request(); - *request->mutable_operand() = operand; - for (int64 dimension : permutation) { - request->add_dimensions(dimension); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Rev( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - OpRequest op_request; - ReverseRequest* request = op_request.mutable_reverse_request(); - *request->mutable_operand() = 
operand; - for (int64 dimension : dimensions) { - request->add_dimensions(dimension); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Sort( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_SORT, operand); -} - -ComputationDataHandle ComputationBuilder::SqrtF32( - const ComputationDataHandle& operand) { - return BinaryOp(BINOP_POW, operand, ConstantR0(0.5), - /*broadcast_dimensions=*/{}); -} - -ComputationDataHandle ComputationBuilder::Pow( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_POW, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::ConvertElementType( - const ComputationDataHandle& operand, PrimitiveType new_element_type) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape_status = GetShape(operand); - if (!shape_status.ok()) { - return ComputationDataHandle(); - } - std::unique_ptr original = shape_status.ConsumeValueOrDie(); - - OpRequest op_request; - ConvertRequest* request = op_request.mutable_convert_request(); - *request->mutable_operand() = operand; - request->set_new_element_type(new_element_type); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BitcastConvertType( - const ComputationDataHandle& operand, PrimitiveType new_element_type) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape_status = GetShape(operand); - if (!shape_status.ok()) { - return ComputationDataHandle(); - } - std::unique_ptr original = shape_status.ConsumeValueOrDie(); - - OpRequest op_request; - ConvertRequest* request = op_request.mutable_bitcast_convert_request(); - *request->mutable_operand() = operand; - request->set_new_element_type(new_element_type); - return RunOpAndParseResponse(&op_request); -} - 
-ComputationDataHandle ComputationBuilder::SquareF32( - const ComputationDataHandle& operand) { - return BinaryOp(BINOP_POW, operand, ConstantR0(2.0), - /*broadcast_dimensions=*/{}); -} - -ComputationDataHandle ComputationBuilder::ReciprocalF32( - const ComputationDataHandle& operand) { - return BinaryOp(BINOP_POW, operand, ConstantR0(-1.0), - /*broadcast_dimensions=*/{}); -} - -ComputationDataHandle ComputationBuilder::Neg( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_NEGATE, operand); -} - -ComputationDataHandle ComputationBuilder::Clz( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_CLZ, operand); -} - -ComputationDataHandle ComputationBuilder::Clamp( - const ComputationDataHandle& min, const ComputationDataHandle& operand, - const ComputationDataHandle& max) { - return TernaryOp(TRIOP_CLAMP, min, operand, max); -} - -ComputationDataHandle ComputationBuilder::UnaryOp( - UnaryOperation unop, const ComputationDataHandle& operand) { - OpRequest op_request; - UnaryOpRequest* request = op_request.mutable_unary_op_request(); - request->set_unop(unop); - *request->mutable_operand() = operand; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BinaryOp( - BinaryOperation binop, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - OpRequest op_request; - BinaryOpRequest* request = op_request.mutable_binary_op_request(); - request->set_binop(binop); - *request->mutable_lhs() = lhs; - *request->mutable_rhs() = rhs; - for (int64 dimension : broadcast_dimensions) { - request->add_broadcast_dimensions(dimension); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::RngOp( - RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters, - const Shape& shape) { - OpRequest op_request; - RngRequest* request = op_request.mutable_rng_request(); - 
request->set_distribution(distribution); - for (const ComputationDataHandle& param : parameters) { - *request->add_parameter() = param; - } - *request->mutable_shape() = shape; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::TernaryOp( - TernaryOperation triop, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, const ComputationDataHandle& ehs) { - OpRequest op_request; - TernaryOpRequest* request = op_request.mutable_ternary_op_request(); - request->set_triop(triop); - *request->mutable_lhs() = lhs; - *request->mutable_rhs() = rhs; - *request->mutable_ehs() = ehs; - return RunOpAndParseResponse(&op_request); -} - -Status ComputationBuilder::SetReturnValue( - const ComputationDataHandle& operand) { - TF_RETURN_IF_ERROR(first_error_); - - SetReturnValueRequest request; - *request.mutable_computation() = computation_.handle(); - *request.mutable_operand() = operand; - - SetReturnValueResponse response; - - VLOG(2) << "making set-handle-to-execute request"; - Status s = client_->stub()->SetReturnValue(&request, &response); - VLOG(2) << "done with request"; - - if (!s.ok()) { - NoteError(s); - return first_error_; - } - - return Status::OK(); -} - -StatusOr ComputationBuilder::IsConstant( - const ComputationDataHandle& operand, int64 num_parameters) { - TF_RETURN_IF_ERROR(first_error_); - - IsConstantRequest request; - *request.mutable_computation() = computation_.handle(); - *request.mutable_operand() = operand; - request.set_num_parameters(num_parameters); - IsConstantResponse response; - - VLOG(2) << "making IsConstant request"; - Status s = client_->stub()->IsConstant(&request, &response); - VLOG(2) << "done with request"; - - if (!s.ok()) { - return s; - } - return response.is_constant(); -} - -StatusOr> ComputationBuilder::ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout, - tensorflow::gtl::ArraySlice parameters) { - TF_RETURN_IF_ERROR(first_error_); - - 
ComputeConstantRequest request; - *request.mutable_computation() = computation_.handle(); - *request.mutable_operand() = operand; - if (output_layout != nullptr) { - *request.mutable_output_layout() = *output_layout; - } - for (const auto& param : parameters) { - *request.add_parameters() = param.ToProto(); - } - - ComputeConstantResponse response; - - VLOG(2) << "making compute-constant request"; - Status s = client_->stub()->ComputeConstant(&request, &response); - VLOG(2) << "done with request"; - - if (!s.ok()) { - return s; - } - - VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}"; - - if (!response.has_literal()) { - return InternalError( - "no computed literal in the provided response in ComputeConstant " - "request"); - } - return Literal::CreateFromProto(response.literal()); -} - -ComputationDataHandle ComputationBuilder::Map( - tensorflow::gtl::ArraySlice operands, - const Computation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { - OpRequest op_request; - MapRequest* request = op_request.mutable_map_request(); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - *request->mutable_to_apply() = computation.handle(); - for (int64 dimension : dimensions) { - request->add_dimensions(dimension); - } - for (const ComputationDataHandle& sop : static_operands) { - *request->add_static_operands() = sop; - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::RngNormal( - const ComputationDataHandle& mu, const ComputationDataHandle& sigma, - const Shape& shape) { - return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape); -} - -ComputationDataHandle ComputationBuilder::RngUniform( - const ComputationDataHandle& a, const ComputationDataHandle& b, - const Shape& shape) { - return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape); -} - -ComputationDataHandle ComputationBuilder::While( - const 
Computation& condition, const Computation& body, - const ComputationDataHandle& init) { - OpRequest op_request; - WhileRequest* request = op_request.mutable_while_request(); - *request->mutable_condition() = condition.handle(); - *request->mutable_body() = body.handle(); - *request->mutable_init() = init; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Gather( - const ComputationDataHandle& input, - const ComputationDataHandle& gather_indices, - const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds) { - OpRequest op_request; - GatherRequest* gather_request = op_request.mutable_gather_request(); - *gather_request->mutable_input() = input; - *gather_request->mutable_gather_indices() = gather_indices; - *gather_request->mutable_dimension_numbers() = dimension_numbers; - for (int64 window_bound : window_bounds) { - gather_request->add_window_bounds(window_bound); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Conditional( - const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const Computation& true_computation, - const ComputationDataHandle& false_operand, - const Computation& false_computation) { - OpRequest op_request; - ConditionalRequest* request = op_request.mutable_conditional_request(); - *request->mutable_predicate() = predicate; - *request->mutable_true_operand() = true_operand; - *request->mutable_true_computation() = true_computation.handle(); - *request->mutable_false_operand() = false_operand; - *request->mutable_false_computation() = false_computation.handle(); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { - OpRequest op_request; - ReduceRequest* request = 
op_request.mutable_reduce_request(); - *request->mutable_operand() = operand; - *request->mutable_init_value() = init_value; - for (int64 dimension : dimensions_to_reduce) { - request->add_dimensions(dimension); - } - *request->mutable_to_apply() = computation.handle(); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::ReduceAll( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - return ComputationDataHandle(); - } - - std::vector all_dimnos(ShapeUtil::Rank(*shape.ValueOrDie())); - std::iota(all_dimnos.begin(), all_dimnos.end(), 0); - return Reduce(operand, init_value, computation, all_dimnos); -} - -ComputationDataHandle ComputationBuilder::ReduceWindow( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding) { - if (!first_error_.ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - return ComputationDataHandle(); - } - - Status padding_valid = - ValidatePaddingValues(AsInt64Slice(shape.ValueOrDie()->dimensions()), - window_dimensions, window_strides); - if (!padding_valid.ok()) { - first_error_ = padding_valid; - return ComputationDataHandle(); - } - - std::vector> padding_values = - MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()), - window_dimensions, window_strides, padding); - return ReduceWindowWithGeneralPadding(operand, init_value, computation, - window_dimensions, window_strides, - padding_values); -} - -ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, 
const Computation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { - OpRequest op_request; - ReduceWindowRequest* request = op_request.mutable_reduce_window_request(); - *request->mutable_operand() = operand; - *request->mutable_to_apply() = computation.handle(); - *request->mutable_init_value() = init_value; - - if (!MakeWindow(window_dimensions, window_strides, padding, {}, {}, - request->mutable_window())) { - NoteError(InternalError("failed to make window")); - return ComputationDataHandle(); - } - - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BatchNormTraining( - const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& offset, float epsilon, int64 feature_index) { - OpRequest op_request; - BatchNormTrainingRequest* request = - op_request.mutable_batch_norm_training_request(); - *request->mutable_operand() = operand; - *request->mutable_scale() = scale; - *request->mutable_offset() = offset; - request->set_epsilon(epsilon); - request->set_feature_index(feature_index); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BatchNormInference( - const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& offset, const ComputationDataHandle& mean, - const ComputationDataHandle& variance, float epsilon, int64 feature_index) { - OpRequest op_request; - BatchNormInferenceRequest* request = - op_request.mutable_batch_norm_inference_request(); - *request->mutable_operand() = operand; - *request->mutable_scale() = scale; - *request->mutable_offset() = offset; - *request->mutable_mean() = mean; - *request->mutable_variance() = variance; - request->set_epsilon(epsilon); - request->set_feature_index(feature_index); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle 
ComputationBuilder::BatchNormGrad( - const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& batch_mean, - const ComputationDataHandle& batch_var, - const ComputationDataHandle& grad_output, float epsilon, - int64 feature_index) { - OpRequest op_request; - BatchNormGradRequest* request = op_request.mutable_batch_norm_grad_request(); - *request->mutable_operand() = operand; - *request->mutable_scale() = scale; - *request->mutable_mean() = batch_mean; - *request->mutable_variance() = batch_var; - *request->mutable_grad_output() = grad_output; - request->set_epsilon(epsilon); - request->set_feature_index(feature_index); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::CrossReplicaSum( - const ComputationDataHandle& operand) { - OpRequest op_request; - CrossReplicaSumRequest* request = - op_request.mutable_cross_replica_sum_request(); - *request->mutable_operand() = operand; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::SelectAndScatter( - const ComputationDataHandle& operand, const Computation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const Computation& scatter) { - if (!first_error_.ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - return ComputationDataHandle(); - } - return SelectAndScatterWithGeneralPadding( - operand, select, window_dimensions, window_strides, - MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()), - window_dimensions, window_strides, padding), - source, init_value, scatter); -} - -ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const Computation& select, - tensorflow::gtl::ArraySlice window_dimensions, - 
tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const Computation& scatter) { - OpRequest op_request; - SelectAndScatterRequest* request = - op_request.mutable_select_and_scatter_request(); - *request->mutable_operand() = operand; - *request->mutable_select() = select.handle(); - *request->mutable_source() = source; - *request->mutable_init_value() = init_value; - *request->mutable_scatter() = scatter.handle(); - - if (!MakeWindow(window_dimensions, window_strides, padding, {}, {}, - request->mutable_window())) { - NoteError(InternalError("failed to make window")); - return ComputationDataHandle(); - } - - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::ReducePrecision( - const ComputationDataHandle& operand, const int exponent_bits, - const int mantissa_bits) { - OpRequest op_request; - ReducePrecisionRequest* request = - op_request.mutable_reduce_precision_request(); - *request->mutable_operand() = operand; - request->set_exponent_bits(exponent_bits); - request->set_mantissa_bits(mantissa_bits); - return RunOpAndParseResponse(&op_request); -} - -void ComputationBuilder::Send(const ComputationDataHandle& operand, - const ChannelHandle& handle) { - OpRequest op_request; - SendRequest* request = op_request.mutable_send_request(); - *request->mutable_operand() = operand; - *request->mutable_channel_handle() = handle; - *op_request.mutable_computation() = computation_.handle(); - RunOpAndNoteError(&op_request); -} - -ComputationDataHandle ComputationBuilder::Recv(const Shape& shape, - const ChannelHandle& handle) { - OpRequest op_request; - RecvRequest* request = op_request.mutable_recv_request(); - *request->mutable_shape() = shape; - *request->mutable_channel_handle() = handle; - return RunOpAndParseResponse(&op_request); -} - -Computation ComputationBuilder::BuildAndNoteError() { - DCHECK(parent_builder_ != 
nullptr); - auto build_status = Build(); - if (!build_status.ok()) { - parent_builder_->NoteError( - AddStatus(build_status.status(), - tensorflow::strings::StrCat("error from: ", name_))); - return Computation(); - } - return build_status.ConsumeValueOrDie(); -} - -StatusOr ComputationBuilder::Build() { - if (!first_error_.ok()) { - string backtrace; - first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); - return AppendStatus(first_error_, backtrace); - } - - if (computation_.IsNull()) { - return FailedPrecondition("no computation was built"); - } - - return {std::move(computation_)}; -} - -/* static */ ConvolutionDimensionNumbers -ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { - ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_input_batch_dimension(kConvBatchDimension); - dimension_numbers.set_input_feature_dimension(kConvFeatureDimension); - dimension_numbers.set_output_batch_dimension(kConvBatchDimension); - dimension_numbers.set_output_feature_dimension(kConvFeatureDimension); - dimension_numbers.set_kernel_output_feature_dimension( - kConvKernelOutputDimension); - dimension_numbers.set_kernel_input_feature_dimension( - kConvKernelInputDimension); - for (int i = 0; i < num_spatial_dims; ++i) { - dimension_numbers.add_input_spatial_dimensions(i + 2); - dimension_numbers.add_kernel_spatial_dimensions(i + 2); - dimension_numbers.add_output_spatial_dimensions(i + 2); - } - return dimension_numbers; -} - -/* static */ StatusOr -ComputationBuilder::CreateConvDimensionNumbers( - int64 input_batch, int64 input_feature, int64 input_first_spatial, - int64 input_second_spatial, int64 output_batch, int64 output_feature, - int64 output_first_spatial, int64 output_second_spatial, - int64 kernel_output_feature, int64 kernel_input_feature, - int64 kernel_first_spatial, int64 kernel_second_spatial) { - if (std::set({input_batch, input_feature, input_first_spatial, - input_second_spatial}) - .size() != 4) 
{ - return FailedPrecondition( - "dimension numbers for the input are not unique: (%lld, %lld, %lld, " - "%lld)", - input_batch, input_feature, input_first_spatial, input_second_spatial); - } - if (std::set({kernel_output_feature, kernel_input_feature, - kernel_first_spatial, kernel_second_spatial}) - .size() != 4) { - return FailedPrecondition( - "dimension numbers for the weight are not unique: (%lld, %lld, %lld, " - "%lld)", - kernel_output_feature, kernel_input_feature, kernel_first_spatial, - kernel_second_spatial); - } - if (std::set({output_batch, output_feature, output_first_spatial, - output_second_spatial}) - .size() != 4) { - return FailedPrecondition( - "dimension numbers for the output are not unique: (%lld, %lld, %lld, " - "%lld)", - output_batch, output_feature, output_first_spatial, - output_second_spatial); - } - ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_input_batch_dimension(input_batch); - dimension_numbers.set_input_feature_dimension(input_feature); - dimension_numbers.add_input_spatial_dimensions(input_first_spatial); - dimension_numbers.add_input_spatial_dimensions(input_second_spatial); - dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature); - dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature); - dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial); - dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial); - dimension_numbers.set_output_batch_dimension(output_batch); - dimension_numbers.set_output_feature_dimension(output_feature); - dimension_numbers.add_output_spatial_dimensions(output_first_spatial); - dimension_numbers.add_output_spatial_dimensions(output_second_spatial); - return dimension_numbers; -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h deleted file mode 100644 index 
ac1eb915cc52df94df71631a7e80de9095f7fafb..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ /dev/null @@ -1,1067 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/array.h" -#include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/array3d.h" -#include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/bitmap.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/stacktrace.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// Wraps an XLA client with a convenient interface for 
building up -// computations. Any errors encountered in building up the computation are -// deferred from being handled until Build() is called. -// -// Thread-compatible. -// -// TODO(b/74197823): Deprecated. Use XlaBuilder instead. -class ComputationBuilder { - public: - // client: client in which to build the computation. - // computation_name: name to use for the built computation. - ComputationBuilder(Client* client, const string& computation_name); - - ~ComputationBuilder(); - - // Returns the client the builder was initialized with. - Client* client() const { return client_; } - - // Returns the computation name. - const string& name() const { return name_; } - - // Sets OpMetadata that will be added to all instructions until cleared. - // - // OpMetadata is often applied to a series of XLA HLO instructions. As a - // result, OpMetadata is set on the Computation Builder. All subsequent - // instructions generated via this Computation Builder will have the same - // OpMetadata attached until a call to ClearOpMetadata. - void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; } - - // Clears the HloMetadata state. - void ClearOpMetadata() { metadata_.Clear(); } - - // Sets an OpSharding that will be attached to all instructions until cleared. - void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } - - // Clears the sharding. Ops will be sharded according to the default placement - // policy. - void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; } - - // Returns the OpSharding that will be attached to all instructions. - const tensorflow::gtl::optional<OpSharding>& sharding() const { - return sharding_; - } - - // Sets the builder to a mode where it will die immediately when an error is - // encountered, rather than producing it in a deferred fashion when Build() is - // called (which is the default).
- void set_die_immediately_on_error(bool enabled) { - die_immediately_on_error_ = enabled; - } - - // Enqueues a "retrieve parameter value" instruction for a parameter that was - // passed to the computation. - ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape, - const string& name); - - // Retrieves the (inferred) shape of the operand in the computation. - StatusOr<std::unique_ptr<Shape>> GetShape( - const ComputationDataHandle& operand); - - // Retrieves the (inferred) result for the current computation's shape. - StatusOr<ProgramShape> GetProgramShape(); - - // Enqueues a constant with the value of the given literal onto the - // computation. - ComputationDataHandle ConstantLiteral(const Literal& literal); - - // Enqueues a constant onto the computation. Methods are templated on the - // native host type (NativeT) which corresponds to a specific XLA - // PrimitiveType as given in the following table: - // - // Native Type PrimitiveType - // ----------------------------- - // bool PRED - // int32 S32 - // int64 S64 - // uint32 U32 - // uint64 U64 - // float F32 - // double F64 - // - // Note: not all primitive types defined in xla_data.proto have a - // corresponding native type yet.
- template <typename NativeT> - ComputationDataHandle ConstantR0(NativeT value); - template <typename NativeT> - ComputationDataHandle ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values); - ComputationDataHandle ConstantR1(const tensorflow::core::Bitmap& values); - template <typename NativeT> - ComputationDataHandle ConstantR2( - std::initializer_list<std::initializer_list<NativeT>> values); - template <typename NativeT> - ComputationDataHandle ConstantFromArrayWithLayout( - const Array<NativeT>& values, const Layout& layout); - template <typename NativeT> - ComputationDataHandle ConstantFromArray(const Array<NativeT>& values); - template <typename NativeT> - ComputationDataHandle ConstantR2FromArray2DWithLayout( - const Array2D<NativeT>& values, const Layout& layout); - template <typename NativeT> - ComputationDataHandle ConstantR2FromArray2D(const Array2D<NativeT>& values); - template <typename NativeT> - ComputationDataHandle ConstantR3FromArray3DWithLayout( - const Array3D<NativeT>& values, const Layout& layout); - template <typename NativeT> - ComputationDataHandle ConstantR3FromArray3D(const Array3D<NativeT>& values); - template <typename NativeT> - ComputationDataHandle ConstantR4FromArray4DWithLayout( - const Array4D<NativeT>& values, const Layout& layout); - template <typename NativeT> - ComputationDataHandle ConstantR4FromArray4D(const Array4D<NativeT>& values); - - // Enqueues a rank one constant (vector) onto the computation. The vector has - // size 'length' and every element has the value 'value'. - template <typename NativeT> - ComputationDataHandle ConstantR1(int64 length, NativeT value); - - // Adds dimensions to an array by duplicating the data in the array. - // - // The new dimensions are inserted on the left, i.e. if - // broadcast_sizes has values {a0, ..., aN} and the operand shape - // has dimensions {b0, ..., bM} then the shape of the output has - // dimensions {a0, ..., aN, b0, ..., bM}. - // - // The new dimensions index into copies of the operand, i.e.
- // - // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] - ComputationDataHandle Broadcast( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice<int64> broadcast_sizes); - - // Enqueues a pad operation onto the computation that pads the given value on - // the edges as well as between the elements of the input. padding_config - // specifies the padding amount for each dimension. - ComputationDataHandle Pad(const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config); - - // Enqueues an operation onto the computation that flattens the operand based - // on the dimension order (major/slowest-varying to minor/fastest-varying) - // given, followed by reshaping it into the shape with the given dimension - // sizes (also major to minor). Conceptually, this is a limited form of - // "shape casting". - ComputationDataHandle Reshape(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice<int64> dimensions, - tensorflow::gtl::ArraySlice<int64> new_sizes); - - // Enqueues an operation onto the computation that collapses the operand, from - // first to last dimension (C order), then reshapes it to the given dimension - // sizes. Conceptually, this is a limited form of "shape casting". - ComputationDataHandle Reshape(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice<int64> new_sizes); - - // Wrapper for Reshape. - // Enqueues an operation to collapse the provided dimensions; e.g. an - // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to - // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must - // be a consecutive, in-order subsequence of the operand dimensions.
- // - // Note that collapsing a single dimension does nothing: - // - // {256} collapsing {0} => {256} - // {1} collapsing {0} => {1} - // - // Collapsing multiple dimensions produces a single result dimension: - // - // {256, 2} collapsing {0,1} => {512} - // {256, 2, 3} collapsing {0,1} => {512, 3} - // - // This could potentially cause data to be moved -- it provides a more - // structured form of reshaping than an arbitrary Reshape operation. - ComputationDataHandle Collapse(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); - - // Enqueues a slice operation onto the computation that slices the operand - // from the start indices to the limit indices; e.g. - // - // x - // [ 0 1 2 3 ] - // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] - // [ 8 9 a b ] - // - // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D - // range notation. - // The strides parameter determines the stride over the slice - ComputationDataHandle Slice(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); - - // Enqueues a slice operation in a given dimension, taking all other - // dimensions as they are; e.g. if dimno is 1 from start_index 2 to - // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand - // for: - // - // array[:, 2:4:1, :] - ComputationDataHandle SliceInDim(const ComputationDataHandle& operand, - int64 start_index, int64 limit_index, - int64 stride, int64 dimno); - - // Enqueues a slice operation onto the computation that slices the 'operand' - // from dynamic start indices which are passed in 'start_indices'. - // The size of the slice in each dimension is passed in 'slice_sizes', - // which specify the end point of exclusive slice intervals in each - // dimension [start, start + size). 
- // The shape of 'start_indices' must be rank == 1, with dimension size - // equal to the rank of the 'operand'. - // Slice index calculations are computed modulo input dimension sizes to - // prevent dynamic start indices from generating out-of-bound array accesses. - ComputationDataHandle DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); - - // Enqueues a dynamic update slice operation onto the computation, which - // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. - // The shape of 'update' determines the shape of the slice of 'operand' - // which is updated. - // The indices specified in 'start_indices' specify the offset of the slice - // of 'operand' which is updated. - // - // update = {10, 11} // calculated at runtime. - // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] - // [4 5 6] => DynamicUpdateSlice(data, update, start) => [4 10 11] - // [7 8 9] [7 8 9 ] - // - // The shape of 'start_indices' must be rank == 1, with dimension size - // equal to the rank of the 'operand'. - // Slice index calculations are computed modulo update dimension sizes to - // prevent dynamic start indices from generating out-of-bound array accesses. - ComputationDataHandle DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices); - - // Enqueues a concatenate instruction onto the computation. 'operands' must - // have >= 1 entry. - ComputationDataHandle ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension); - - // Enqueues a tracing operation onto the computation; the computation will emit - // a logging message with the operand. - void Trace(const string& tag, const ComputationDataHandle& operand); - - // Enqueues a conditional-move-like select operation onto the computation; - // predicated on pred, selects between on_true and on_false.
- ComputationDataHandle Select(const ComputationDataHandle& pred, - const ComputationDataHandle& on_true, - const ComputationDataHandle& on_false); - - // Enqueues a tuple-creation instruction onto the computation. - ComputationDataHandle Tuple( - tensorflow::gtl::ArraySlice elements); - - // Enqueues a tuple-element-get instruction onto the computation. - ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data, - int64 index); - - // Enqueues an equal-to comparison instruction onto the computation. - ComputationDataHandle Eq( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a not-equal comparison instruction onto the computation. - ComputationDataHandle Ne( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a greater-or-equal comparison instruction onto the computation. - ComputationDataHandle Ge( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a greater-than comparison instruction onto the computation. - ComputationDataHandle Gt( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a less-than comparison instruction onto the computation. - ComputationDataHandle Lt( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a less-or-equal comparison instruction onto the computation. - ComputationDataHandle Le( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a dot instruction onto the computation. 
- ComputationDataHandle Dot(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs); - - // Enqueues a general dot instruction onto the computation. - ComputationDataHandle DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - const DotDimensionNumbers& dimension_numbers); - - // Default dimension numbers used for a 2D convolution. - static constexpr int64 kConvBatchDimension = 0; - static constexpr int64 kConvFeatureDimension = 1; - static constexpr int64 kConvFirstSpatialDimension = 2; - static constexpr int64 kConvSecondSpatialDimension = 3; - static constexpr int64 kConvKernelOutputDimension = 0; - static constexpr int64 kConvKernelInputDimension = 1; - static constexpr int64 kConvKernelFirstSpatialDimension = 2; - static constexpr int64 kConvKernelSecondSpatialDimension = 3; - - // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for - // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for - // the kernel operand - // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. - static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( - int num_spatial_dims = 2); - - // Creates a ConvolutionDimensionNumbers with the given arguments. Returns an - // error if either the input or the weight dimension numbers have conflicts. - static StatusOr CreateConvDimensionNumbers( - int64 input_batch, int64 input_feature, int64 input_first_spatial, - int64 input_second_spatial, int64 output_batch, int64 output_feature, - int64 output_first_spatial, int64 output_second_spatial, - int64 kernel_output_feature, int64 kernel_input_feature, - int64 kernel_first_spatial, int64 kernel_second_spatial); - - // Enqueues a convolution instruction onto the computation, which uses the - // default convolution dimension numbers. 
- ComputationDataHandle Conv(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - Padding padding); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration in the format returned by MakePadding(). - ComputationDataHandle ConvWithGeneralPadding( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided dimension numbers configuration. - ComputationDataHandle ConvWithGeneralDimensions( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration as well as the dimension numbers. - ComputationDataHandle ConvGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration, dilation factors and dimension numbers. - ComputationDataHandle ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); - - // Enqueues an FFT instruction onto the computation, of the given type and - // with the given FFT length. 
- ComputationDataHandle Fft(const ComputationDataHandle& operand, - FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); - - // Enqueues an infeed instruction onto the computation, which writes data of - // the given shape to the infeed buffer of the device. - ComputationDataHandle Infeed(const Shape& shape, const string& config = ""); - - // Enqueues an outfeed instruction onto the computation. This instruction - // generates outgoing data transfers for the given data. - // - // shape_with_layout communicates the laid out shape that we want to outfeed - // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error - // will occur. - void Outfeed(const ComputationDataHandle& operand, - const Shape& shape_with_layout, const string& outfeed_config); - - // Enqueues a call instruction onto the computation. - ComputationDataHandle Call( - const Computation& computation, - tensorflow::gtl::ArraySlice operands); - - // Enqueues a custom call instruction onto the computation. - // During code generation, a call instruction is emitted which targets a - // symbol with the name |call_target_name|. The |operands| are passed to the - // call instruction. |shape| is the resultant shape. - ComputationDataHandle CustomCall( - const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape); - - // Enqueues a pseudo-op to represent host-side computation data-dependencies. - // During code generation, host send and receive operations will be generated - // to transfer |operands| to the host and a single result of |shape| back to - // the device. Host send/recv operations are emitted using |channel_name|. - // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO - // instruction scheduling. 
- ComputationDataHandle HostCompute( - tensorflow::gtl::ArraySlice operands, - const string& channel_name, int64 cost_estimate_ns, const Shape& shape); - - // The following methods enqueue element-wise binary arithmetic operations - // onto the computation. The shapes of the operands have to match unless one - // of the operands is a scalar, or an explicit broadcast dimension is given - // (see g3doc for more details). - - // Enqueues a complex compose instruction onto the computation. - ComputationDataHandle Complex( - const ComputationDataHandle& real, const ComputationDataHandle& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a complex conjugate instruction onto the computation. - ComputationDataHandle Conj(const ComputationDataHandle& operand); - - // Enqueues an add instruction onto the computation. - ComputationDataHandle Add( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a subtract instruction onto the computation. - ComputationDataHandle Sub( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a multiply instruction onto the computation. - ComputationDataHandle Mul( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a divide instruction onto the computation. - ComputationDataHandle Div( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a remainder instruction onto the computation. - ComputationDataHandle Rem( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a max instruction onto the computation. 
- ComputationDataHandle Max( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a min instruction onto the computation. - ComputationDataHandle Min( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Element-wise logical operators - ComputationDataHandle And( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - ComputationDataHandle Or( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - ComputationDataHandle Xor( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - ComputationDataHandle Not(const ComputationDataHandle& operand); - - ComputationDataHandle ShiftLeft( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - ComputationDataHandle ShiftRightArithmetic( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - ComputationDataHandle ShiftRightLogical( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Reduces an array among the provided dimensions, given "computation" as a - // reduction operator. - ComputationDataHandle Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); - - // Convenience wrapper around the above that reduces all the dimensions in the - // operand shape. 
- ComputationDataHandle ReduceAll(const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, - const Computation& computation); - - // Enqueues a windowed reduce instruction onto the computation. - ComputationDataHandle ReduceWindow( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding); - - // As ReduceWindow(), but the padding is given in the format - // returned by MakePadding(). - ComputationDataHandle ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); - - // Returns the sum of the operand value across all replicas. All replicas - // supply one input to the sum and all replicas receive the resulting sum. - ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand); - - // Enqueues an operation that scatters the `source` array to the selected - // indices of each window. - ComputationDataHandle SelectAndScatter( - const ComputationDataHandle& operand, const Computation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const Computation& scatter); - - // As SelectAndScatter(), but the padding is given in the format - // returned by MakePadding(). 
- ComputationDataHandle SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const Computation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const Computation& scatter); - - // Enqueues an abs instruction onto the computation. - ComputationDataHandle Abs(const ComputationDataHandle& operand); - - // Enqueues an atan2 instruction onto the computation. - ComputationDataHandle Atan2( - const ComputationDataHandle& y, const ComputationDataHandle& x, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues an exp instruction onto the computation. - ComputationDataHandle Exp(const ComputationDataHandle& operand); - - // Enqueues a floor instruction onto the computation. - ComputationDataHandle Floor(const ComputationDataHandle& operand); - - // Enqueues a ceil instruction onto the computation. - ComputationDataHandle Ceil(const ComputationDataHandle& operand); - - // Enqueues a round instruction onto the computation, rounding to the nearest - // integer, with half-way cases rounding away from zero. - ComputationDataHandle Round(const ComputationDataHandle& operand); - - // Enqueues a log instruction (natural logarithm) onto the computation. - ComputationDataHandle Log(const ComputationDataHandle& operand); - - // Enqueues a sign instruction onto the computation. - ComputationDataHandle Sign(const ComputationDataHandle& operand); - - // Enqueues a cosine instruction onto the computation. - ComputationDataHandle Cos(const ComputationDataHandle& operand); - - // Enqueues a sine instruction onto the computation. - ComputationDataHandle Sin(const ComputationDataHandle& operand); - - // Enqueues a tanh instruction onto the computation. - ComputationDataHandle Tanh(const ComputationDataHandle& operand); - - // Enqueues a real-part instruction onto the computation.
- ComputationDataHandle Real(const ComputationDataHandle& operand); - - // Enqueues an imaginary-part instruction onto the computation. - ComputationDataHandle Imag(const ComputationDataHandle& operand); - - // Enqueues a float32 sqrt instruction onto the computation. - // (float32 is specified as there is an implicit float32 0.5f constant - // exponent). - ComputationDataHandle SqrtF32(const ComputationDataHandle& operand); - - // Enqueues a float32 square instruction onto the computation. - // (float32 is specified as there is an implicit float32 2.0f constant - // exponent). - ComputationDataHandle SquareF32(const ComputationDataHandle& operand); - - // Enqueues a lhs^rhs computation onto the computation. - ComputationDataHandle Pow( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues an operator that tests if the operand's values are finite, i.e., - // not Inf or NaN. Defined only for floating-point types. Returns an array of - // booleans with the same shape where entries are true iff the corresponding - // entry was finite. - ComputationDataHandle IsFinite(const ComputationDataHandle& operand); - - // Enqueues a convert instruction onto the computation that changes the - // element type of the operand array to new_element_type. - ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, - PrimitiveType new_element_type); - - // Enqueues a no-op instruction onto the computation that changes - // the element type of the operand array to new_element_type. The - // bit-widths of the source and destination element types must be - // identical. - ComputationDataHandle BitcastConvertType(const ComputationDataHandle& operand, - PrimitiveType new_element_type); - - // Enqueues a float32 reciprocal instruction onto the computation. - // (float32 is specified as there is an implicit float32 -1.0f constant - // exponent).
- // - // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the - // shape of the operand. - ComputationDataHandle ReciprocalF32(const ComputationDataHandle& operand); - - // Enqueues a negate instruction onto the computation. - ComputationDataHandle Neg(const ComputationDataHandle& operand); - - // Enqueues a count-leading-zeros instruction onto the computation. - ComputationDataHandle Clz(const ComputationDataHandle& operand); - - // Enqueues a transpose instruction onto the computation. - ComputationDataHandle Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation); - - // Enqueues a reverse instruction onto the computation. The order of the - // elements in the given dimensions is reversed (i.e., the element at index i - // is moved to index dimension_size - 1 - i). - ComputationDataHandle Rev(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); - - // Enqueues a sort (as increasing order) instruction onto the computation. - ComputationDataHandle Sort(const ComputationDataHandle& operand); - - // Enqueues a clamp instruction onto the computation. - ComputationDataHandle Clamp(const ComputationDataHandle& min, - const ComputationDataHandle& operand, - const ComputationDataHandle& max); - - // Enqueues a map instruction onto the computation. - ComputationDataHandle Map( - tensorflow::gtl::ArraySlice operands, - const Computation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands = {}); - - // Enqueues a N(mu, sigma) random number generation instruction onto the - // computation. - ComputationDataHandle RngNormal(const ComputationDataHandle& mu, - const ComputationDataHandle& sigma, - const Shape& shape); - - // Enqueues a U(a, b) random number generation instruction onto the - // computation. Returns values in the semi-open interval [a, b). 
- ComputationDataHandle RngUniform(const ComputationDataHandle& a, - const ComputationDataHandle& b, - const Shape& shape); - - // Enqueues a while node onto the computation. - ComputationDataHandle While(const Computation& condition, - const Computation& body, - const ComputationDataHandle& init); - - // Enqueues a conditional node onto the computation. - ComputationDataHandle Conditional(const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const Computation& true_computation, - const ComputationDataHandle& false_operand, - const Computation& false_computation); - - // Enqueues a ReducePrecision node onto the computation. - ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand, - const int exponent_bits, - const int mantissa_bits); - - // Enqueues a Gather node onto the computation. - ComputationDataHandle Gather( - const ComputationDataHandle& input, - const ComputationDataHandle& gather_indices, - const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds); - - // Enqueues a Send node onto the computation, to send the given operand to - // a Recv instruction that shares the same channel handle. - void Send(const ComputationDataHandle& operand, const ChannelHandle& handle); - - // Enqueues a Recv node onto the computation. The data comes from a Send - // instruction that shares the same channel handle and its shape must - // be the same as the given shape. - ComputationDataHandle Recv(const Shape& shape, const ChannelHandle& handle); - - // Returns true if 'operand' is a compile-time constant. A compile-time - // constant does not depend on parameters with index greater than or equal to - // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`. - // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a - // compile-time constant without evaluating the computation. 
- StatusOr IsConstant(const ComputationDataHandle& operand, - int64 num_parameters = 0); - - // Normalizes operand across spatial and batch dimensions for each feature. - // - // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` - // is the normalized result and batch_mean and batch_var are the mean and - // variance, respectively, across batch for the operand. - ComputationDataHandle BatchNormTraining(const ComputationDataHandle& operand, - const ComputationDataHandle& scale, - const ComputationDataHandle& offset, - float epsilon, int64 feature_index); - - // Normalizes operand across spatial and batch dimensions for each feature. - // - // `BatchNormInference` is equivalent to calling `BatchNormTraining` without - // computing `mean` and `variance` for each batch inside the operation. It - // uses the input `mean` and `variance` instead as estimated values. The - // purpose of this op is to reduce latency in inference, hence the name - // `BatchNormInference`. - // - // The output has the same shape as `operand`, and contains the normalized - // values for each batch. - ComputationDataHandle BatchNormInference( - const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& offset, const ComputationDataHandle& mean, - const ComputationDataHandle& variance, float epsilon, - int64 feature_index); - - // Calculates the gradients of a batch norm op. - // - // The inputs `batch_mean` and `batch_var` represent the mean and variance - // across the batch. 
- // - // Returns a tuple of three elements: - // - grad_operand: Gradient with respect to input `operand` - // - grad_offset: Gradient with respect to input `offset` - // - grad_scale: Gradient with respect to input `scale` - ComputationDataHandle BatchNormGrad(const ComputationDataHandle& operand, - const ComputationDataHandle& scale, - const ComputationDataHandle& batch_mean, - const ComputationDataHandle& batch_var, - const ComputationDataHandle& grad_output, - float epsilon, int64 feature_index); - - // Computes the value of a constant indicated by a - // ComputationDataHandle using a non-optimized interpreter on the host. - // - // The operand must be from the computation currently being built - - // i.e., returned from this builder with no intervening call to - // Build(). This happens to currently work regardless of that, but - // that may stop working at any time. - // - // The operand must represent a constant value, which in this case - // means that it must not statically depend on any parameter of the - // computation that is being built other than the ones specified on the - // parameter list. The parameters in the list will be indexed by their - // parameter id property so the number of parameters specified should be at - // least as many as the largest used parameter index. - // - // `IsConstant` can be used to test whether a computation is a compile-time - // constant without evaluating it. `ComputeConstant` only succeeds for - // computations where `IsConstant` returns true. - // - // This functionality can be useful when translating a computation - // into XLA where something that looked dynamic is required by - // XLA to be specified as a constant. E.g. the source - // computation (outside of XLA) may include a dynamic - // computation of the shape of something and ComputeConstant lets - // you determine what the value of that computation is in the case - // where the value can be determined at compile time.
- // - // If output_layout is non-null, then the output of the computation - // will be stored using that layout. - StatusOr> ComputeConstant( - const ComputationDataHandle& operand, - const Layout* output_layout = nullptr, - tensorflow::gtl::ArraySlice parameters = {}); - - // Returns a new ComputationBuilder whose resultant Computation is used only - // by this ComputationBuilder. The sub-ComputationBuilder has the same - // die_immediately_on_error behavior as the parent. - std::unique_ptr CreateSubBuilder( - const string& computation_name); - - // Modifies the computation being built so that executions of it - // will return the value associated with operand, rather than the - // last expression enqueued on the ComputationBuilder. Any subsequent - // operations added to the ComputationBuilder will not have any effect unless - // SetReturnValue is called again. - Status SetReturnValue(const ComputationDataHandle& operand); - - // Builds the computation with the requested operations, or returns a non-ok - // status. - StatusOr Build(); - - // Builds the computation with the requested operations, or notes an error in - // the parent ComputationBuilder and returns an empty computation if building - // failed. This function is intended to be used where the returned - // Computation is only used by the parent ComputationBuilder and hence further - // operation on the returned Computation will simply be error'ed out if an - // error occurred while building this computation. If the built computation is - // to be used by a ComputationBuilder other than the parent ComputationBuilder - // then Build() should be used instead. - Computation BuildAndNoteError(); - - // Returns the first error that was encountered while building the - // computation. When an error is encountered, by default we return a vacuous - // ComputationDataHandle and inform the user of the error that occurred while - // building the computation when they make a final call to Build(). 
- // - // See also set_die_immediately_on_error(). - Status first_error() const { return first_error_; } - - private: - // Limited checking of convolution parameters. Returns false on - // error. - bool VerifyConvolution(const Shape& lhs_shape, const Shape& rhs_shape, - const ConvolutionDimensionNumbers& dimension_numbers); - - // The parent ComputationBuilder of a sub-ComputationBuilder. The - // parent_builder_ will be the nullptr if not a sub-ComputationBuilder. - ComputationBuilder* parent_builder_{nullptr}; - - // Helper function for creating a Window proto from user-supplied - // data. Returns true if the user-supplied data was valid. - bool MakeWindow(tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - Window* window); - - // Internal helper method that does the building for an arbitrary unary op. - ComputationDataHandle UnaryOp(UnaryOperation unop, - const ComputationDataHandle& operand); - - // Internal helper method that does the building for an arbitrary binary op. - // broadcast_dimensions specifies which dimensions to use for broadcasting - // when the operation is between tensors of different ranks. - ComputationDataHandle BinaryOp( - BinaryOperation binop, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); - - // Internal helper method that does the building for an arbitrary ternary op. - ComputationDataHandle TernaryOp(TernaryOperation triop, - const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - const ComputationDataHandle& ehs); - - // Internal helper method that does the building for a random number generator - // of a given distribution with an explicitly specified shape. 
- ComputationDataHandle RngOp( - RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters, - const Shape& shape); - - // Populates computation_ with a valid object or returns a failing status. - // This is used before any given operation is enqueued. - Status PrepareComputation(); - - // Notes that the error occurred by: - // * storing it internally and capturing a backtrace if it's the first error - // (this deferred value will be produced on the call to Build()) - // * dying if die_immediately_on_error_ is true - void NoteError(const Status& error); - - // Helper function that runs the given op_request, filling in op_response. - // Before the op is run, PrepareComputation is called, and common fields in - // the op_request are filled in. - Status RunOp(OpRequest* op_request, OpResponse* op_response); - - // Helper function that calls RunOp and calls NoteError on failures. - void RunOpAndNoteError(OpRequest* op_request); - - // Helper function that calls RunOp and either returns the output computation - // data handle (on success) or a vacuous computation data handle (on failure). - ComputationDataHandle RunOpAndParseResponse(OpRequest* op_request); - - // Helper function that implements GetShape without noting errors. This makes - // it easier to ensure the real GetShape will note errors on every error path. - StatusOr> GetShapeWithoutNoteError( - const ComputationDataHandle& operand); - - string name_; // Name to use for the built computation. - - // The first error encountered while building the computation. - // This is OK until the first error is encountered. - Status first_error_; - - // The saved stack trace from the point at which the first error occurred. - tensorflow::SavedStackTrace first_error_backtrace_; - - // The computation that operations are enqueued onto. - Computation computation_; - - // The client that the computation is created in. Not owned. 
- Client* client_; - - // Mode bit that indicates whether to die when a first error is encountered. - bool die_immediately_on_error_ = false; - - // The metadata to attach to each op. This is structured as a "modal"-like - // operation, in order to simplify client code (and not sprinkle this metadata - // throughout the TensorFlow op kernel implementations). - OpMetadata metadata_; - - // Sharding for this operator. This is structured as a "model"-like operation, - // in order to simplify client code, similar to metadata_. - tensorflow::gtl::optional sharding_; - - TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder); -}; - -template -ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) { - return ConstantLiteral(*Literal::CreateR0(value)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR1( - tensorflow::gtl::ArraySlice values) { - return ConstantLiteral(*Literal::CreateR1(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR1(int64 length, - NativeT value) { - Literal literal(ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), {length})); - literal.PopulateWithValue(value); - return ConstantLiteral(literal); -} - -inline ComputationDataHandle ComputationBuilder::ConstantR1( - const tensorflow::core::Bitmap& values) { - return ConstantLiteral(*Literal::CreateR1(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR2( - std::initializer_list> values) { - return ConstantLiteral(*Literal::CreateR2(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout( - const Array& values, const Layout& layout) { - return ConstantLiteral( - *Literal::CreateFromArrayWithLayout(values, layout)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantFromArray( - const Array& values) { - return ConstantLiteral(*Literal::CreateFromArray(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout( - const 
Array2D& values, const Layout& layout) { - return ConstantLiteral( - *Literal::CreateFromArrayWithLayout(values, layout)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D( - const Array2D& values) { - return ConstantLiteral(*Literal::CreateR2FromArray2D(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout) { - return ConstantLiteral( - *Literal::CreateR3FromArray3DWithLayout(values, layout)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D( - const Array3D& values) { - return ConstantFromArray(values); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout) { - return ConstantFromArrayWithLayout(values, layout); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D( - const Array4D& values) { - return ConstantFromArray(values); -} - -// RAII-style object: sets the current sharding assignment in builder on -// construction, and sets back to the previous assignment on destruction. 
-class ScopedShardingAssignment { - public: - ScopedShardingAssignment(xla::ComputationBuilder* builder, - tensorflow::gtl::optional sharding) - : builder_(builder), prev_sharding_(builder->sharding()) { - SetSharding(sharding); - } - - ~ScopedShardingAssignment() { SetSharding(prev_sharding_); } - - private: - void SetSharding(const tensorflow::gtl::optional& sharding) { - if (sharding.has_value()) { - builder_->SetSharding(sharding.value()); - } else { - builder_->ClearSharding(); - } - } - - xla::ComputationBuilder* const builder_; - tensorflow::gtl::optional prev_sharding_; - - TF_DISALLOW_COPY_AND_ASSIGN(ScopedShardingAssignment); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 6e3c5cb484b8f1ef053fa287a4d462aeb886e530..7dee41f6a05025ec196b78e54015e8e71777031f 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -87,6 +87,18 @@ ExecutableBuildOptions::dump_optimized_hlo_proto_to() const { return dump_optimized_hlo_proto_to_; } +ExecutableBuildOptions& +ExecutableBuildOptions::set_dump_unoptimized_hlo_proto_to( + tensorflow::StringPiece dirpath) { + dump_unoptimized_hlo_proto_to_ = dirpath.ToString(); + return *this; +} + +const tensorflow::gtl::optional& +ExecutableBuildOptions::dump_unoptimized_hlo_proto_to() const { + return dump_unoptimized_hlo_proto_to_; +} + ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to( tensorflow::StringPiece dirpath) { dump_per_pass_hlo_proto_to_ = dirpath.ToString(); diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 11f10983606fe02b1edb11a260edde8e5f9a726f..9dc9be4423564fb967b247c2d1df31099cb80237 100644 --- 
a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ #include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/optional.h" @@ -64,6 +65,13 @@ class ExecutableBuildOptions { tensorflow::StringPiece dirpath); const tensorflow::gtl::optional& dump_optimized_hlo_proto_to() const; + // If set, specifies a dirpath to dump the start-of-optimization-pipeline HLO + // protobuf to (as in DebugOptions). + ExecutableBuildOptions& set_dump_unoptimized_hlo_proto_to( + tensorflow::StringPiece dirpath); + const tensorflow::gtl::optional& dump_unoptimized_hlo_proto_to() + const; + // If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs // to (as in DebugOptions). ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to( @@ -76,6 +84,13 @@ class ExecutableBuildOptions { ExecutableBuildOptions& set_hlo_profile(bool enabled); tensorflow::gtl::optional hlo_profile() const; + void add_disabled_hlo_pass(tensorflow::StringPiece pass_name) { + disabled_hlo_passes_.push_back(std::string(pass_name)); + } + const tensorflow::gtl::ArraySlice disabled_hlo_passes() const { + return disabled_hlo_passes_; + } + // Returns a string representation of the build options, suitable for // debugging. 
string ToString() const; @@ -87,8 +102,10 @@ class ExecutableBuildOptions { bool result_layout_set_ = false; tensorflow::gtl::optional generate_hlo_graph_; tensorflow::gtl::optional dump_optimized_hlo_proto_to_; + tensorflow::gtl::optional dump_unoptimized_hlo_proto_to_; tensorflow::gtl::optional dump_per_pass_hlo_proto_to_; DeviceMemoryAllocator* device_allocator_ = nullptr; + std::vector disabled_hlo_passes_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/client/global_data.cc b/tensorflow/compiler/xla/client/global_data.cc index 40f59eaa68ebeb47edbd2afbeabad0cd2623ebc6..2986d4060013703873b2cffb6aacbb012606d16f 100644 --- a/tensorflow/compiler/xla/client/global_data.cc +++ b/tensorflow/compiler/xla/client/global_data.cc @@ -31,7 +31,7 @@ GlobalData::~GlobalData() { *request.mutable_data() = handle_; UnregisterResponse response; VLOG(1) << "requesting to unregister " << handle_.ShortDebugString(); - tensorflow::Status s = parent_->Unregister(&request, &response); + Status s = parent_->Unregister(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 9cd87f74735ff50df8a3382723c7d045ff6c9e52..3380af9f303b1dc2cec09aa37410ec40cdeaa526 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -92,21 +92,6 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, return MakeFakeDataViaDeviceOrDie(shape, client); } -std::vector> MakeFakeArgumentsOrDie( - const Computation& computation, Client* client) { - auto program_shape = - client->GetComputationShape(computation).ConsumeValueOrDie(); - - // For every (unbound) parameter that the computation wants, we manufacture - // some arbitrary data so that we can invoke the computation. 
- std::vector> fake_arguments; - for (const Shape& parameter : program_shape->parameters()) { - fake_arguments.push_back(MakeFakeDataOrDie(parameter, client)); - } - - return fake_arguments; -} - std::vector> MakeFakeArgumentsOrDie( const XlaComputation& computation, Client* client) { CHECK(computation.proto().has_program_shape()) diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h index 9e06141b1f13d24cd033b72e31ee3a0442fe6a37..dc613099e2b42a60d0c11a654ab5cd41f8bd4f6f 100644 --- a/tensorflow/compiler/xla/client/lib/testing.h +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -32,12 +32,6 @@ namespace xla { std::unique_ptr MakeFakeDataOrDie(const Shape& shape, Client* client); -// Returns vector of GlobalData handles of fake data (created using -// MakeFakeDataOrDie) that are correctly shaped arguments for the given -// computation. -std::vector> MakeFakeArgumentsOrDie( - const Computation& computation, Client* client); - // Returns vector of GlobalData handles of fake data (created using // MakeFakeDataOrDie) that are correctly shaped arguments for the given // xla computation. 
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 1acc6f86860e526b5ff737c45041a863f21da145..ae0308020d014e038d2f0fd7de6c5f372d6cbed1 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -48,7 +48,7 @@ LocalExecutable::LocalExecutable(std::unique_ptr executable, << "Must have a valid device ordinal that the executable was built for."; } -tensorflow::Status LocalExecutable::ValidateExecutionOptions( +Status LocalExecutable::ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& run_options, const Backend& backend) { const ComputationLayout& host_computation_layout = @@ -185,7 +185,7 @@ StatusOr LocalExecutable::Run( run_options, backend_->StreamBorrower(), backend_->eigen_intra_op_thread_pool()); - if (executable_->dumping()) { + if (executable_->dumping_snapshot()) { return ExecuteAndDump(&service_options, arguments); } return executable_->ExecuteOnStreamWrapper( @@ -195,36 +195,36 @@ StatusOr LocalExecutable::Run( StatusOr LocalExecutable::ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const tensorflow::gtl::ArraySlice arguments) { - executable_->session_module()->set_execution_platform( + executable_->hlo_snapshot()->set_execution_platform( backend_->platform()->Name()); - TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->session_module())); + TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot())); TF_ASSIGN_OR_RETURN( ScopedShapedBuffer result, executable_->ExecuteOnStream(run_options, arguments, /*hlo_execution_profile=*/nullptr)); - TF_RETURN_IF_ERROR(RecordResult(&result, executable_->session_module())); - TF_RETURN_IF_ERROR(executable_->DumpSessionModule()); + TF_RETURN_IF_ERROR(RecordResult(&result, executable_->hlo_snapshot())); + TF_RETURN_IF_ERROR(executable_->DumpHloSnapshot()); return std::move(result); } -tensorflow::Status 
LocalExecutable::RecordArguments( +Status LocalExecutable::RecordArguments( const tensorflow::gtl::ArraySlice arguments, - SessionModule* session_module) { - session_module->clear_arguments(); + HloSnapshot* hlo_snapshot) { + hlo_snapshot->clear_arguments(); for (const ShapedBuffer* argument : arguments) { TF_ASSIGN_OR_RETURN(std::unique_ptr literal, LiteralFromShapedBuffer(*argument)); - *session_module->add_arguments() = literal->ToProto(); + *hlo_snapshot->add_arguments() = literal->ToProto(); } return Status::OK(); } -tensorflow::Status LocalExecutable::RecordResult( - const ShapedBuffer* result, SessionModule* session_module) { - session_module->clear_result(); +Status LocalExecutable::RecordResult(const ShapedBuffer* result, + HloSnapshot* hlo_snapshot) { + hlo_snapshot->clear_result(); TF_ASSIGN_OR_RETURN(std::unique_ptr literal, LiteralFromShapedBuffer(*result)); - *session_module->mutable_result() = literal->ToProto(); + *hlo_snapshot->mutable_result() = literal->ToProto(); return Status::OK(); } @@ -261,25 +261,6 @@ Backend* LocalClient::mutable_backend() { return local_service_->mutable_backend(); } -StatusOr> LocalClient::Compile( - const Computation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& options) { - ExecutableBuildOptions updated_options = options; - if (options.device_ordinal() == -1) { - updated_options.set_device_ordinal(default_device_ordinal()); - VLOG(3) << "Set device ordinal to default value of: " - << updated_options.device_ordinal(); - } - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - local_service_->CompileExecutable(computation.handle(), argument_layouts, - updated_options)); - return WrapUnique(new LocalExecutable(std::move(executable), - local_service_->mutable_backend(), - updated_options)); -} - StatusOr> LocalClient::Compile( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -323,6 +304,11 @@ StatusOr> 
LocalClient::ShapedBufferToLiteral( shaped_buffer); } +StatusOr LocalClient::GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number) { + return local_service_->GlobalDataToShapedBuffer(data, replica_number); +} + Status LocalClient::TransferToInfeedLocal(const Literal& literal, int device_ordinal) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index d8fd7a5623d1fecdcff6851aa3e3538822fb50da..4d9e0d7cd9d6ddebead1e12b23e94b529038039b 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -19,13 +19,13 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" @@ -59,25 +59,30 @@ class LocalExecutable { // Validates that the given arguments and options satisfy various constraints // of the computation. - tensorflow::Status ValidateExecutionOptions( + // + // The given ExecutableRunOptions override any values from legacy_flags + // (TF_XLA_FLAGS environment variable). 
+ Status ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& run_options, const Backend& backend); // Records the computation in a SessionModule proto with the arguments used to // invoke it, and the result. Enabled by flag: --tla_dump_executions_to. + // + // The given ServiceExecutableRunOptions override any values from legacy_flags + // (TF_XLA_FLAGS environment variable). StatusOr ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const tensorflow::gtl::ArraySlice arguments); // Records the arguments used to invoke the computation in a SessionModule // proto. - tensorflow::Status RecordArguments( + Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, - SessionModule* session_module); + HloSnapshot* hlo_snapshot); // Records the result of the computation in a SessionModule proto. - tensorflow::Status RecordResult(const ShapedBuffer* result, - SessionModule* session_module); + Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot); // Returns a literal containing the contents of the given ShapedBuffer. StatusOr> LiteralFromShapedBuffer( @@ -108,17 +113,11 @@ class LocalClient : public Client { LocalClient(const LocalClient&) = delete; void operator=(const LocalClient&) = delete; - // Build and return a LocalExecutable object. The executable is compiled using - // the given argument layouts and options. - StatusOr> Compile( - const Computation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& options); - // Build and return a LocalExecutable object. The executable is compiled using // the given XlaComputation, argument layouts and options. // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. + // The given ExecutableBuildOptions override any values from legacy_flags + // (TF_XLA_FLAGS environment variable). 
StatusOr> Compile( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -137,6 +136,11 @@ class LocalClient : public Client { StatusOr> ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer); + // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid + // as long as the handle is valid. + StatusOr GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number); + // Transfer the given literal to the infeed queue of the given device. // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does // not inherit from Client and there is no possibility of confusion with diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD index 0d6e207971ec64515ec5e6da292910920edd101a..507a2dc5f088e159156f0ef3d663ba2819f6a2d4 100644 --- a/tensorflow/compiler/xla/client/xla_client/BUILD +++ b/tensorflow/compiler/xla/client/xla_client/BUILD @@ -37,7 +37,6 @@ cc_library( ], ) -# TODO(b/74197823): Replace computation_builder with xla_builder. 
cc_library( name = "xla_builder", srcs = ["xla_builder.cc"], diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 1899983e442116d3ebf8a3e79b0515653cd624cb..5e17cc4dfb0b225712e94041970545ff19f03b98 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -57,16 +57,6 @@ bool CanBeRoot(HloOpcode opcode) { } } -StatusOr> GetOperandShapes( - tensorflow::gtl::ArraySlice operands) { - std::vector operand_shapes; - for (const XlaOp& operand : operands) { - TF_ASSIGN_OR_RETURN(const Shape& shape, operand.GetShape()); - operand_shapes.push_back(shape); - } - return operand_shapes; -} - } // namespace StatusOr XlaBuilder::GetShape(const XlaOp& op) const { @@ -76,12 +66,14 @@ StatusOr XlaBuilder::GetShape(const XlaOp& op) const { return instr->shape(); } -StatusOr XlaOp::GetShape() const { - if (builder_ == nullptr) { - return InvalidArgument( - "cannot GetShape for an invalid XlaOp with handle %lld", handle()); +StatusOr> XlaBuilder::GetOperandShapes( + tensorflow::gtl::ArraySlice operands) const { + std::vector operand_shapes; + for (const XlaOp& operand : operands) { + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); + operand_shapes.push_back(shape); } - return builder_->GetShape(*this); + return operand_shapes; } XlaBuilder::XlaBuilder(const string& computation_name) @@ -286,7 +278,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, const XlaOp& operand) { TF_RETURN_IF_ERROR(first_error_); - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); CHECK(ShapeUtil::IsScalar(operand_shape) || ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)); @@ -325,7 +317,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& 
operand) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferUnaryOpShape(unop, operand_shape)); return AddInstruction(std::move(instr), unop, {operand}); @@ -337,8 +329,8 @@ XlaOp XlaBuilder::BinaryOp( tensorflow::gtl::ArraySlice broadcast_dimensions) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape()); - TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferBinaryOpShape( binop, lhs_shape, rhs_shape, broadcast_dimensions)); @@ -374,12 +366,12 @@ XlaOp XlaBuilder::BinaryOp( updated_rhs = !should_broadcast_lhs ? 
broadcasted_operand : rhs; } - TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, updated_lhs.GetShape()); + TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, GetShape(updated_lhs)); if (!ShapeUtil::SameDimensions(instr.shape(), updated_lhs_shape)) { TF_ASSIGN_OR_RETURN(updated_lhs, AddBroadcastSequence(instr.shape(), updated_lhs)); } - TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, updated_rhs.GetShape()); + TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, GetShape(updated_rhs)); if (!ShapeUtil::SameDimensions(instr.shape(), updated_rhs_shape)) { TF_ASSIGN_OR_RETURN(updated_rhs, AddBroadcastSequence(instr.shape(), updated_rhs)); @@ -393,9 +385,9 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, const XlaOp& ehs) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape()); - TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape()); - TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, ehs.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); + TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, GetShape(ehs)); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferTernaryOpShape( triop, lhs_shape, rhs_shape, ehs_shape)); @@ -437,7 +429,7 @@ XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions); } -XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) { +XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = literal.shape(); @@ -485,7 +477,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, XlaOp XlaBuilder::Broadcast( const XlaOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { return NoteErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, 
operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( const Shape& shape, ShapeInference::InferBroadcastShape(operand_shape, broadcast_sizes)); @@ -633,7 +625,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { return NoteErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(const Shape& shape, ShapeInference::InferReshapeShape( operand_shape, dimensions, new_sizes)); @@ -647,7 +639,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, XlaOp XlaBuilder::Reshape(const XlaOp& operand, tensorflow::gtl::ArraySlice new_sizes) { return NoteErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(auto shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand)); std::vector dimensions(shape.dimensions_size()); std::iota(dimensions.begin(), dimensions.end(), 0); return Reshape(operand, dimensions, new_sizes); @@ -1002,7 +994,7 @@ XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, const tensorflow::gtl::ArraySlice fft_length) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferFftShape(operand_shape, fft_type, fft_length)); @@ -1173,6 +1165,10 @@ XlaOp XlaBuilder::Exp(const XlaOp& operand) { return UnaryOp(HloOpcode::kExp, operand); } +XlaOp XlaBuilder::Expm1(const XlaOp& operand) { + return UnaryOp(HloOpcode::kExpm1, operand); +} + XlaOp XlaBuilder::Floor(const XlaOp& operand) { return UnaryOp(HloOpcode::kFloor, operand); } @@ -1189,6 +1185,10 @@ XlaOp XlaBuilder::Log(const XlaOp& operand) { return UnaryOp(HloOpcode::kLog, operand); } 
+XlaOp XlaBuilder::Log1p(const XlaOp& operand) { + return UnaryOp(HloOpcode::kLog1p, operand); +} + XlaOp XlaBuilder::Sign(const XlaOp& operand) { return UnaryOp(HloOpcode::kSign, operand); } @@ -1225,7 +1225,7 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand, tensorflow::gtl::ArraySlice permutation) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferTransposeShape(operand_shape, permutation)); @@ -1613,13 +1613,35 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) { return NoteErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); + const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {}); + auto b = CreateSubBuilder("sum"); + b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"), + b->Parameter(/*parameter_number=*/1, scalar_shape, "y")); + TF_ASSIGN_OR_RETURN(auto computation, b->Build()); + return CrossReplicaSum(operand, computation, /*replica_group_ids=*/{}, + /*channel_id=*/tensorflow::gtl::nullopt); + }); +} +XlaOp XlaBuilder::CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids, + const tensorflow::gtl::optional& channel_id) { + return NoteErrorOrReturn([&]() -> StatusOr { + if (!replica_group_ids.empty() || channel_id.has_value()) { + return Unimplemented( + "replica_group_ids and channel_id and is not supported in AllReduce"); + } + + HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferCrossReplicaSumShape({&operand_shape})); + AddCalledComputation(computation, &instr); + 
return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum, {operand}); }); @@ -1948,11 +1970,18 @@ StatusOr XlaBuilder::LookUpInstruction( const XlaOp& op) const { TF_RETURN_IF_ERROR(first_error_); + if (op.builder_ == nullptr) { + return InvalidArgument( + "invalid XlaOp with handle %lld; the builder of this op is freed", + op.handle()); + } if (op.builder_ != this) { - return InvalidArgument("invalid XlaOp with handle %lld", op.handle()); + return InvalidArgument( + "XlaOp with handle %lld is built by builder '%s', but is trying to use " + "it in builder '%s'", + op.handle(), op.builder_->name().c_str(), this->name().c_str()); } - TF_RET_CHECK(op.builder_ == this); if (op.handle() >= instructions_.size() || op.handle() < 0) { return InvalidArgument("no XlaOp value %lld", op.handle()); } diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index 4955f1515d66af00ddf72e4c7621292a590e662c..532cae014848e17b24ee720a3c3dc5f99c89dfe5 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -13,10 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// TODO(b/74197823): Replace computation_builder.h with this file. -// -// This is NOT YET ready to use. - #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ @@ -48,15 +44,11 @@ class XlaBuilder; // This represents an instruction that has been enqueued using the XlaBuilder. // This is used to pass to subsequent computations that depends upon the // instruction as an operand. -// -// TODO(b/74197823): Replace xla::ComputationDataHandle with this one. 
class XlaOp { public: XlaOp() : handle_(0), builder_(nullptr) {} ~XlaOp() {} - StatusOr GetShape() const; - const XlaBuilder* builder() const { return builder_; } bool operator==(const XlaOp& rhs) const { @@ -87,8 +79,6 @@ class XlaOp { // A convenient interface for building up computations. // // Thread-compatible. -// -// TODO(b/74197823): Replace xla::ComputationBuilder with this one. class XlaBuilder { public: // computation_name: name to use for the built computation. @@ -139,7 +129,7 @@ class XlaBuilder { // Enqueues a constant with the value of the given literal onto the // computation. - XlaOp ConstantLiteral(const Literal& literal); + XlaOp ConstantLiteral(const LiteralSlice& literal); // Enqueues a constant onto the computation. Methods are templated on the // native host type (NativeT) which corresponds to a specific XLA @@ -542,6 +532,29 @@ class XlaBuilder { // supply one input to the sum and all replicas receive the resulting sum. XlaOp CrossReplicaSum(const XlaOp& operand); + // Enqueues an operation that do an AllReduce of the operand cross cores. Here + // AllReduce means doing a reduction on the input operand cross cores and then + // broadcasting the reduction result to those cores. The reduction function is + // defined by `computation`, which should be a commutative computation on + // scalars, e.g., add, min, or max. The way that AllReduce is applied is + // configured by: + // + // - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all + // replicas belong to one group. Allreduce will be applied within subgroups. + // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, + // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. + // + // - `channel_id`: for Allreduce nodes from different models, if they have the + // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be + // applied cross models. 
+ // + // TODO(b/79737069): Rename this to AllReduce when it's ready to use. + XlaOp CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids = {}, + const tensorflow::gtl::optional& channel_id = + tensorflow::gtl::nullopt); + // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, @@ -571,6 +584,9 @@ class XlaBuilder { // Enqueues an exp instruction onto the computation. XlaOp Exp(const XlaOp& operand); + // Enqueues an expm1 instruction onto the computation. + XlaOp Expm1(const XlaOp& operand); + // Enqueues a floor instruction onto the computation. XlaOp Floor(const XlaOp& operand); @@ -584,6 +600,9 @@ class XlaBuilder { // Enqueues an log instruction (natural logarithm) onto the computation. XlaOp Log(const XlaOp& operand); + // Enqueues an log1p instruction (log(x+1)) onto the computation. + XlaOp Log1p(const XlaOp& operand); + // Enqueues a sign instruction onto the computation. XlaOp Sign(const XlaOp& operand); @@ -847,6 +866,10 @@ class XlaBuilder { // computation and fills the root_id in the pointer. StatusOr GetProgramShape(int64* root_id) const; + // Returns shapes for the operands. + StatusOr> GetOperandShapes( + tensorflow::gtl::ArraySlice operands) const; + // A visitor which checks whether an operation is a compile-time constant, // meaning that it doesn't depend on any parameters, or on any stateful // operation such as `RngNormal` or `Infeed`. The visitor walks the @@ -981,8 +1004,6 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { // RAII-style object: sets the current sharding assignment in builder on // construction, and sets back to the previous assignment on destruction. -// -// TODO(b/74197823): This is a part of a NOT YET ready refactor. 
class XlaScopedShardingAssignment { public: XlaScopedShardingAssignment(xla::XlaBuilder* builder, diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc index ce984564d016ce65fa6c932f3cda290cc0d75a4a..2df3ea3af0d4fcfb9bc803feebd96f09042ab1f3 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc @@ -76,7 +76,7 @@ TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { auto y = b.Parameter(1, y_shape, "y"); auto add = b.Add(x, y, /*broadcast_dimensions=*/{0, 1}); - TF_ASSERT_OK_AND_ASSIGN(auto add_shape, add.GetShape()); + TF_ASSERT_OK_AND_ASSIGN(auto add_shape, b.GetShape(add)); EXPECT_TRUE(ShapeUtil::Equal(add_shape, x_shape)); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); @@ -188,8 +188,10 @@ TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { builder.Add(p0, p0); auto statusor = builder.Build(); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Do not add XlaOp from builder b1 to builder main")); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "built by builder 'b1', but is trying to use it in builder 'main'")); } TEST_F(XlaBuilderTest, ReshapeDefaultOrder) { diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.h b/tensorflow/compiler/xla/client/xla_client/xla_computation.h index b70b57e9ffec40188f246f5e884146012c02f4a2..0ffba208b1f8683fe1d26107cbfd096b856267f1 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_computation.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.h @@ -25,8 +25,6 @@ limitations under the License. namespace xla { // The computation graph that the user builds up with the XlaBuilder. -// -// TODO(b/74197823): Replace xla::Computation with this one. 
class XlaComputation { public: XlaComputation() : unique_id_(-1) {} diff --git a/tensorflow/compiler/xla/error_spec.h b/tensorflow/compiler/xla/error_spec.h new file mode 100644 index 0000000000000000000000000000000000000000..a1463aa15941b9c265db94e2eb3cc176fab6695b --- /dev/null +++ b/tensorflow/compiler/xla/error_spec.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_ERROR_SPEC_H_ +#define TENSORFLOW_COMPILER_XLA_ERROR_SPEC_H_ + +namespace xla { + +// Structure describing permissible absolute and relative error bounds. +struct ErrorSpec { + explicit ErrorSpec(float aabs, float arel = 0, bool relaxed_nans = false) + : abs(aabs), rel(arel), relaxed_nans(relaxed_nans) {} + + float abs; // Absolute error bound. + float rel; // Relative error bound. + + // If relaxed_nans is true then any result is valid if we are expecting NaNs. + // In effect, this allows the tested operation to produce incorrect results + // for inputs outside its mathematical domain. 
+ bool relaxed_nans; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_ERROR_SPEC_H_ diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index c6f8f6766e9d0156d0c68306af214443f584a9fe..e8f29b83291a7cb238dc25b9f4bb743fe426a162 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -65,6 +65,16 @@ void SetDefaultLayoutToContainer( return layout; } +/* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor( + tensorflow::gtl::ArraySlice major_to_minor) { + Layout layout; + layout.set_format(DENSE); + for (int i = major_to_minor.size() - 1; i >= 0; i--) { + layout.add_minor_to_major(major_to_minor[i]); + } + return layout; +} + /* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) { Layout layout; layout.set_format(SPARSE); @@ -88,8 +98,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } // namespace /* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) { + if (ShapeUtil::IsOpaque(shape) || ShapeUtil::IsToken(shape)) { + // Opaque and token types have empty layouts. + return Layout(); + } + // A Layout proto corresponds to a single array, not a tuple. - DCHECK(!ShapeUtil::IsTuple(shape)); + CHECK(ShapeUtil::IsArray(shape)); return CreateDefaultLayoutForRank(shape.dimensions_size()); } @@ -116,14 +131,15 @@ Layout CreateDefaultLayoutForRank(int64 rank) { SetToDefaultLayout(&element_shape); } shape->clear_layout(); - } else if (ShapeUtil::IsOpaque(*shape)) { - shape->clear_layout(); - } else { + } else if (ShapeUtil::IsArray(*shape)) { shape->mutable_layout()->set_format(DENSE); tensorflow::protobuf::RepeatedField* minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); minor_to_major->Resize(shape->dimensions_size(), 0); SetDefaultLayoutToContainer(minor_to_major); + } else { + // Opaque, token types etc. have no layout. 
+ shape->clear_layout(); } } @@ -140,8 +156,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { LayoutUtil::SetToDefaultLayout(program_shape->mutable_result()); } -/* static */ tensorflow::Status LayoutUtil::ValidateLayoutInShape( - const Shape& shape) { +/* static */ Status LayoutUtil::ValidateLayoutInShape(const Shape& shape) { if (ShapeUtil::IsTuple(shape)) { // Tuple shape. if (shape.has_layout()) { @@ -150,30 +165,34 @@ Layout CreateDefaultLayoutForRank(int64 rank) { for (auto& element_shape : shape.tuple_shapes()) { TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape)); } - return tensorflow::Status::OK(); - } else if (ShapeUtil::IsOpaque(shape)) { - if (shape.has_layout()) { - return InvalidArgument("opaque should not have a layout field"); - } - return tensorflow::Status::OK(); - } else { - // Array shape. + return Status::OK(); + } else if (ShapeUtil::IsArray(shape)) { if (!shape.has_layout()) { return InvalidArgument("shape %s does not have a layout", ShapeUtil::HumanString(shape).c_str()); } return ValidateLayoutForShape(shape.layout(), shape); + } else { + // Token, opaque, etc. shape. 
+ if (shape.has_layout()) { + return InvalidArgument( + "shape of primitive type %s should not have a layout", + PrimitiveType_Name(shape.element_type()).c_str()); + } + return Status::OK(); } } -/* static */ tensorflow::Status LayoutUtil::ValidateLayoutForShape( - const Layout& layout, const Shape& shape) { +/* static */ Status LayoutUtil::ValidateLayoutForShape(const Layout& layout, + const Shape& shape) { if (ShapeUtil::IsTuple(shape)) { return InvalidArgument("a single Layout is not valid for tuple shapes"); } - if (ShapeUtil::IsOpaque(shape)) { - return tensorflow::Status::OK(); + if (!ShapeUtil::IsArray(shape)) { + return InvalidArgument( + "shape of primitive type %s should not have a layout", + PrimitiveType_Name(shape.element_type()).c_str()); } if (layout.format() == INVALID_FORMAT) { @@ -225,7 +244,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } } - return tensorflow::Status::OK(); + return Status::OK(); } /* static */ void LayoutUtil::ClearLayout(Shape* shape) { @@ -264,7 +283,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::IsPadded(const Shape& shape) { - if (ShapeUtil::IsTuple(shape) || !HasLayout(shape) || + if (!ShapeUtil::IsArray(shape) || !HasLayout(shape) || shape.layout().padded_dimensions_size() == 0) { return false; } @@ -314,7 +333,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) { // Tuple shape: all subshapes must have a layout. return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), [](const Shape& s) { return HasLayout(s); }); - } else if (ShapeUtil::IsOpaque(shape)) { + } else if (!ShapeUtil::IsArray(shape)) { + // Opaque, token types etc. ignore layout. return true; } return shape.has_layout() && shape.layout().format() != INVALID_FORMAT; @@ -384,7 +404,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { namespace { // Internal helper for recursively copying layouts. 
-tensorflow::Status CopyLayoutInternal(const Shape& src, Shape* dst) { +Status CopyLayoutInternal(const Shape& src, Shape* dst) { if (ShapeUtil::IsTuple(src) != ShapeUtil::IsTuple(*dst)) { return InvalidArgument( "cannot copy layout from shape: shape structure differs"); @@ -411,25 +431,21 @@ tensorflow::Status CopyLayoutInternal(const Shape& src, Shape* dst) { dst->clear_layout(); } } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace /* static */ -tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, - Shape* dst) { +Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { return CopyLayoutInternal(src, dst); } /* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) { - if (ShapeUtil::IsTuple(lhs) != ShapeUtil::IsTuple(rhs)) { - return false; - } if (ShapeUtil::IsTuple(lhs)) { - if (ShapeUtil::TupleElementCount(lhs) != - ShapeUtil::TupleElementCount(rhs)) { + if (!ShapeUtil::IsTuple(rhs) || ShapeUtil::TupleElementCount(lhs) != + ShapeUtil::TupleElementCount(rhs)) { return false; } for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) { @@ -438,9 +454,12 @@ tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, } } return true; - } else { + } else if (ShapeUtil::IsArray(lhs)) { return ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) && LayoutUtil::Equal(lhs.layout(), rhs.layout()); + } else { + // Layouts of non-array and non-tuple shapes is ignored. + return true; } } diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 6cec7501015e2dff6b5e56e20b793a5458618501..739bbe73675c7fb855627006028eafdf703d6540 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -20,9 +20,9 @@ limitations under the License. 
#include +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -36,6 +36,10 @@ class LayoutUtil { // convenience function for protobuf construction.) static Layout MakeLayout(tensorflow::gtl::ArraySlice minor_to_major); + // Similar to MakeLayout, but take indices in reverse order. + static Layout MakeLayoutFromMajorToMinor( + tensorflow::gtl::ArraySlice major_to_minor); + // Creates a sparse layout with the given maximum number of elements. (This is // a convenience function for protobuf construction.) static Layout MakeSparseLayout(int64 max_sparse_elements); @@ -61,12 +65,12 @@ class LayoutUtil { static void SetToDefaultLayout(ProgramShape* program_shape); // Validates that the layout within the given shape is correct. - static tensorflow::Status ValidateLayoutInShape(const Shape& shape); + static Status ValidateLayoutInShape(const Shape& shape); // Validates that the provided layout satisfies invariants for the given // shape. - static tensorflow::Status ValidateLayoutForShape(const Layout& layout, - const Shape& shape); + static Status ValidateLayoutForShape(const Layout& layout, + const Shape& shape); // Clears the layout in the given Shape. After this function is called, // HasLayout will return false for the shape. @@ -179,8 +183,7 @@ class LayoutUtil { // tuples. 'src' and 'dst' need not be compatible but the two shapes must // have the same tuple structure (if any) and arrays must have the same // rank. within the shapes must have the same number of dimensions. - static tensorflow::Status CopyLayoutBetweenShapes(const Shape& src, - Shape* dst); + static Status CopyLayoutBetweenShapes(const Shape& src, Shape* dst); // Returns true if the layouts of lhs and rhs are equal, false // otherwise. 
Recursively compares layouts of tuples. diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 4fd1d818e3e3b417eee9f6b14bb598bfb9480c6e..e4c825450dcd45a8fbeaacbb2ad145f94307176f 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -218,6 +218,47 @@ TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) { "elements, but shape is rank")); } +TEST_F(LayoutUtilTest, CopyTokenLayout) { + Shape src = ShapeUtil::MakeTokenShape(); + Shape dst = ShapeUtil::MakeTokenShape(); + + // Layouts are trivially the same for token types and copying layouts should + // be a nop. + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + +TEST_F(LayoutUtilTest, CopyOpaqueLayout) { + Shape src = ShapeUtil::MakeOpaqueShape(); + Shape dst = ShapeUtil::MakeOpaqueShape(); + + // Layouts are trivially the same for opaque types and copying layouts should + // be a nop. 
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + +TEST_F(LayoutUtilTest, CopyTupleLayoutWithTokenAndOpaque) { + Shape src = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3}, {0, 1}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})})}); + Shape dst = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})}); + + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + TEST_F(LayoutUtilTest, ClearLayoutTuple) { Shape shape = ShapeUtil::MakeTupleShape( {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), @@ -236,6 +277,16 @@ TEST_F(LayoutUtilTest, ClearLayoutTuple) { EXPECT_FALSE(shape.tuple_shapes(2).tuple_shapes(1).has_layout()); } +TEST_F(LayoutUtilTest, ClearLayoutOpaqueAndToken) { + // Opaque and token types trivially have layouts. 
+ for (Shape shape : + {ShapeUtil::MakeOpaqueShape(), ShapeUtil::MakeTokenShape()}) { + EXPECT_TRUE(LayoutUtil::HasLayout(shape)); + LayoutUtil::ClearLayout(&shape); + EXPECT_TRUE(LayoutUtil::HasLayout(shape)); + } +} + TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) { Shape shape = ShapeUtil::MakeTupleShape( {MakeShapeWithLayout(F32, {2, 3, 4}, {1, 0, 2}), diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index bc8405703b02dc1b0c4c87005ea3c15372552157..f42fb92359f40ec763866af094972046f6407ae1 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -47,6 +47,12 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { // Set cudnn batchnorm off by default; it does not provide a performance win // on average. flags->set_xla_gpu_use_cudnn_batchnorm(false); + + // Run all GPU work on one stream by default. Using multiple streams + // increases memory usage and we lack strong motivating benchmarks for tuning + // the heuristics needed to decide when to run on multiple streams. See + // b/77879207. + flags->set_xla_gpu_disable_multi_streaming(true); } // Allocates flag_values and flag_objects; this function must not be called more diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf9679cafec72c2e9dc5796e9058c6703239c508 --- /dev/null +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -0,0 +1,741 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/literal_comparison.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" + +using tensorflow::strings::Appendf; +using tensorflow::strings::Printf; +using tensorflow::strings::StrAppend; +using tensorflow::strings::StrCat; + +namespace xla { +namespace literal_comparison { +namespace { + +// Helper function for comparing a floating point type, FloatT, bitwise equal +// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT +// -- on miscompare, a nice error message is given in the AssertionFailure. +template +Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { + auto ulhs = tensorflow::bit_cast(lhs); + auto urhs = tensorflow::bit_cast(rhs); + auto lhs_double = static_cast(lhs); + auto rhs_double = static_cast(rhs); + if (ulhs != urhs) { + return InvalidArgument( + "floating values are not bitwise-equal; and equality testing " + "was requested: %s=%g=%a vs %s=%g=%a", + StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, lhs_double, + StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double); + } + return Status::OK(); +} + +// Templated comparator that specializes for float equality comparison with the +// bitwise helper above (this is the un-specialized fallback, to just use the +// default gunit implementation). 
+template +Status CompareEqual(NativeT lhs, NativeT rhs) { + if (lhs == rhs) { + return Status::OK(); + } + return InvalidArgument("Expected equality of these values:\n %s\n %s", + StrCat(lhs).c_str(), StrCat(rhs).c_str()); +} + +// Specializations for floating types that do bitwise comparisons when equality +// comparison is requested. +template <> +Status CompareEqual(bfloat16 lhs, bfloat16 rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> +Status CompareEqual(Eigen::half lhs, Eigen::half rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> +Status CompareEqual(float lhs, float rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> +Status CompareEqual(double lhs, double rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> +Status CompareEqual(complex64 lhs, complex64 rhs) { + auto res = CompareEqual(lhs.real(), rhs.real()); + if (!res.ok()) { + return res; + } + return CompareEqual(lhs.imag(), rhs.imag()); +} + +// A recursive function which iterates through every index of expected and +// actual literal and compares their values elementwise. Returns true if all +// elements are equal. +template +Status Equal(LiteralSlice expected, LiteralSlice actual, + tensorflow::gtl::MutableArraySlice multi_index, + int64 dimension) { + if (dimension == expected.shape().dimensions_size()) { + NativeT expected_value = expected.Get(multi_index); + NativeT actual_value = actual.Get(multi_index); + return CompareEqual(expected_value, actual_value); + } + + Status result; + for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { + multi_index[dimension] = i; + result.Update(Equal(expected, actual, multi_index, dimension + 1)); + } + return result; +} + +// Gets the total element count. For tuples, this is not the count of tuple +// elements, but the sum of elements of each tuple element. 
+int64 RecursiveElementCount(const Shape& shape) { + if (ShapeUtil::IsTuple(shape)) { + const int64 tuple_elements = ShapeUtil::TupleElementCount(shape); + int64 total = 0; + for (int64 i = 0; i < tuple_elements; ++i) { + total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); + } + return total; + } else { + return ShapeUtil::ElementsIn(shape); + } +} + +// Returns whether the actual and expected values are mismatched with respect to +// nans. 'relaxed_nans' is interpreted as in xla::ErrorSpec. +template +bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) { + if (relaxed_nans) { + return !std::isnan(expected) && std::isnan(actual); + } else { + return std::isnan(expected) != std::isnan(actual); + } +} + +template <> +bool NanMismatch(complex64 expected, complex64 actual, + bool relaxed_nans) { + return NanMismatch(expected.real(), actual.real(), relaxed_nans) || + NanMismatch(expected.imag(), actual.imag(), relaxed_nans); +} + +template <> +bool NanMismatch(half expected, half actual, bool relaxed_nans) { + return NanMismatch(static_cast(expected), + static_cast(actual), relaxed_nans); +} + +// Converts the given floating-point value to a string. +template +string FpValueToString(NativeT value) { + return Printf("%8.4g", static_cast(value)); +} + +template <> +string FpValueToString(complex64 value) { + return Printf("%8.4g + %8.4fi", value.real(), value.imag()); +} + +// Returns the absolute value of the given floating point value. This function +// is used instead of std::abs directly in order to allow type-dependent +// implementations for NearComparator. +template +float FpAbsoluteValue(NativeT value) { + return std::abs(value); +} + +template <> +float FpAbsoluteValue(bfloat16 value) { + return FpAbsoluteValue(static_cast(value)); +} + +template <> +float FpAbsoluteValue(half value) { + return FpAbsoluteValue(static_cast(value)); +} + +// Helper class for comparing floating-point literals within an error bound. 
+template +class NearComparator { + public: + // Compares the two array literals elementwise and returns a comparison + // result. The comparison is ok() if all actual and expected elements are + // within the given error bound. In case of error, the status contains a + // detailed message about the discrepancy. + static Status Compare(const LiteralSlice& expected, + const LiteralSlice& actual, ErrorSpec error, + bool detailed_message, + const MiscompareCallback& miscompare_callback) { + NearComparator comparator(expected, actual, error, + detailed_message, miscompare_callback); + return comparator.Run(); + } + + private: + // Data structure encapsulating metadata about a single element mismatch. + struct Mismatch { + NativeT actual; + NativeT expected; + float rel_error; + float abs_error; + + // The linear index of the failure within the shape. This linear index is + // from the 'actual' literal. + int64 linear_index; + + bool operator<(const Mismatch& other) const { + return rel_error < other.rel_error; + } + + string ToString(const Shape& shape) const { + return Printf( + "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g", + FpValueToString(actual).c_str(), FpValueToString(expected).c_str(), + Literal::MultiIndexAsString( + IndexUtil::LinearIndexToMultidimensionalIndex(shape, + linear_index)) + .c_str(), + rel_error, abs_error); + } + }; + + NearComparator(const LiteralSlice& expected, const LiteralSlice& actual, + ErrorSpec error, bool detailed_message, + const MiscompareCallback& miscompare_callback) + : expected_(expected), + actual_(actual), + error_(error), + detailed_message_(detailed_message), + miscompare_callback_(miscompare_callback), + abs_value_buckets_(kAbsValueBucketBounds.size() - 1, {0, 0}), + abs_error_buckets_(kErrorBucketBounds.size(), 0), + rel_error_buckets_(kErrorBucketBounds.size(), 0) {} + + // Runs the comparison between expected and actual literals. 
+ Status Run() { + VLOG(1) << "expected:"; + XLA_VLOG_LINES(1, ToStringTruncated(expected_)); + VLOG(1) << "actual:"; + XLA_VLOG_LINES(1, ToStringTruncated(actual_)); + + // If the shapes mismatch, we simply fail the expectation instead of + // printing out data, as it's a type error rather than a value error. + TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape())); + if (!ShapeUtil::IsArray(expected_.shape())) { + return InvalidArgument("Expected array shape; got %s.", + ShapeUtil::HumanString(expected_.shape()).c_str()); + } + + mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED)); + mismatches_.PopulateWithValue(false); + + CompareLiterals(); + + if (num_mismatches_ == 0) { + return Status::OK(); + } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) { + miscompare_callback_(expected_, actual_, mismatches_); + } + return InvalidArgument("%s", ErrorMessage().c_str()); + } + + // Insert the given absolute value into the absolute value bucket vector. The + // bounds of the buckets are given by kAbsValueBucketBounds. + void UpdateAbsValueBucket(NativeT value, bool is_mismatch) { + // Adjust the bucket containing the absolute values of the 'actual' + // elements. + const float abs_value = FpAbsoluteValue(value); + for (int i = 0; i < abs_value_buckets_.size(); ++i) { + if (i == abs_value_buckets_.size() - 1 || + (abs_value >= kAbsValueBucketBounds[i] && + abs_value < kAbsValueBucketBounds[i + 1])) { + // The first value of the pair is the count of elements in the bucket, + // the second is the count of mismatches in the bucket. + abs_value_buckets_[i].first++; + if (is_mismatch) { + abs_value_buckets_[i].second++; + } + return; + } + } + } + + // Insert the given error into the given error bucket vector. 
+ void UpdateErrorBucket( + float error, tensorflow::gtl::MutableArraySlice error_buckets) { + CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size()); + for (int i = 0; i < error_buckets.size(); ++i) { + if (error >= kErrorBucketBounds[i]) { + error_buckets[i]++; + } + } + } + + // Compares the two given elements from the expected and actual literals at + // the given literal_index and keeps track of various mismatch statistics. + void CompareValues(NativeT expected, NativeT actual, int64 linear_index) { + const bool is_nan_mismatch = + NanMismatch(expected, actual, error_.relaxed_nans); + float abs_error; + float rel_error; + if (actual == expected) { + abs_error = 0; + rel_error = 0; + } else if (is_nan_mismatch) { + num_nan_mismatches_++; + // A nan mismatch is considered to have infinite error. rel_error is used + // for sorting a std::set of the top mismatchs, and a nan value here will + // result in undefined behavior because nan's do not satisfy the strict + // weak ordering requirement of std containers. + abs_error = std::numeric_limits::infinity(); + rel_error = std::numeric_limits::infinity(); + } else { + abs_error = FpAbsoluteValue(actual - expected); + rel_error = abs_error / FpAbsoluteValue(expected); + } + const bool is_abs_mismatch = abs_error > error_.abs; + const bool is_rel_mismatch = rel_error > error_.rel; + const bool is_mismatch = + is_nan_mismatch || (is_abs_mismatch && is_rel_mismatch); + + // Update the error of the relative bucket only if the *absolute* error + // bound is exceeded and vice versa. + if (is_abs_mismatch) { + num_abs_mismatches_++; + UpdateErrorBucket(rel_error, &rel_error_buckets_); + } + if (is_rel_mismatch) { + num_rel_mismatches_++; + UpdateErrorBucket(abs_error, &abs_error_buckets_); + } + + UpdateAbsValueBucket(actual, is_mismatch); + + if (!is_mismatch) { + return; + } + + num_mismatches_++; + + // Keep track of the kTopRelativeErrorCount relative error mismatches. 
+ if (top_rel_mismatches_.size() < kTopRelativeErrorCount || + rel_error > top_rel_mismatches_.begin()->rel_error) { + Mismatch mismatch = {actual, expected, rel_error, abs_error, + linear_index}; + top_rel_mismatches_.insert(mismatch); + if (top_rel_mismatches_.size() > kTopRelativeErrorCount) { + top_rel_mismatches_.erase(top_rel_mismatches_.begin()); + } + } + + mismatches_.data()[linear_index] = true; + } + + // Compares the two literals elementwise. + void CompareLiterals() { + // Fast path optimization for the case were layouts match. + if (LayoutUtil::Equal(actual_.shape().layout(), + expected_.shape().layout())) { + tensorflow::gtl::ArraySlice expected_data = + expected_.data(); + tensorflow::gtl::ArraySlice actual_data = + actual_.data(); + const int64 len = expected_data.size(); + for (int64 i = 0; i < len; ++i) { + CompareValues(expected_data[i], actual_data[i], i); + } + return; + } + std::vector multi_index(ShapeUtil::Rank(actual_.shape()), 0); + CompareLiteralsSlow(0, &multi_index); + } + + // Slow path for CompareLiterals when 'actual' and 'expected' literals have + // different layouts. In this case, multidimensional indices are constructed + // and indexed for each element. + void CompareLiteralsSlow(int64 dimension, std::vector* multi_index) { + if (dimension == multi_index->size()) { + CompareValues(expected_.Get(*multi_index), + actual_.Get(*multi_index), + IndexUtil::MultidimensionalIndexToLinearIndex( + actual_.shape(), *multi_index)); + } else { + for (int64 i = 0; i < expected_.shape().dimensions(dimension); ++i) { + (*multi_index)[dimension] = i; + CompareLiteralsSlow(dimension + 1, multi_index); + } + } + } + + // Returns an error message string with a detailed breakdown of the + // mismatches. Called after calling Run(). + string ErrorMessage() { + string out; + int64 element_count = ShapeUtil::ElementsIn(actual_.shape()); + + auto percent_string = [](float a, float b) { + float pct = b == 0.0 ? 
0.0 : 100.0 * a / b; + return Printf("%0.4f%%", pct); + }; + + Appendf(&out, + "\nMismatch count %lld (%s) in shape %s (%lld elements), abs bound " + "%g, rel bound %g\n", + num_mismatches_, + percent_string(num_mismatches_, element_count).c_str(), + ShapeUtil::HumanString(actual_.shape()).c_str(), + ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel); + if (num_nan_mismatches_ > 0) { + StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n"); + } + Appendf(&out, "Top relative error mismatches:\n"); + for (auto it = top_rel_mismatches_.rbegin(); + it != top_rel_mismatches_.rend(); ++it) { + StrAppend(&out, " ", it->ToString(actual_.shape()).c_str(), "\n"); + } + + if (!detailed_message_) { + return out; + } + + StrAppend(&out, "Absolute magnitude breakdown of actual values:\n"); + CHECK_EQ(abs_value_buckets_.size() + 1, kAbsValueBucketBounds.size()); + for (int i = 0; i < abs_value_buckets_.size(); ++i) { + const int64 bucket_size = abs_value_buckets_[i].first; + const int64 bucket_mismatches = abs_value_buckets_[i].second; + string mismatch_str = bucket_mismatches > 0 + ? 
Printf(", mismatches %lld", bucket_mismatches) + : ""; + Appendf(&out, " %-6g <= x < %-6g : %7lld (%9s)%s\n", + kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1], + bucket_size, percent_string(bucket_size, element_count).c_str(), + mismatch_str.c_str()); + } + + auto print_accum_buckets = [&](const string& header, int64 total, + tensorflow::gtl::ArraySlice buckets) { + StrAppend(&out, header, ":\n"); + Appendf(&out, " < %-6g : %7lld (%s)\n", kErrorBucketBounds[0], + total - buckets[0], + percent_string(total - buckets[0], total).c_str()); + CHECK_EQ(buckets.size(), kErrorBucketBounds.size()); + for (int i = 0; i < kErrorBucketBounds.size(); ++i) { + Appendf(&out, " >= %-6g : %7lld (%s)\n", kErrorBucketBounds[i], + buckets[i], percent_string(buckets[i], total).c_str()); + } + }; + Appendf(&out, "Elements exceeding abs error bound %g: %lld (%s)\n", + error_.abs, num_abs_mismatches_, + percent_string(num_abs_mismatches_, element_count).c_str()); + print_accum_buckets( + "Relative error breakdown of elements exceeding abs error bound", + num_abs_mismatches_, rel_error_buckets_); + Appendf(&out, "Elements exceeding rel error bound %g: %lld (%s)\n", + error_.rel, num_rel_mismatches_, + percent_string(num_rel_mismatches_, element_count).c_str()); + print_accum_buckets( + "Absolute error breakdown of elements exceeding rel error bound", + num_rel_mismatches_, abs_error_buckets_); + return out; + } + + // 'actual' and 'expected' literals being compared. + LiteralSlice expected_; + LiteralSlice actual_; + + // The error bounds of the comparison. + ErrorSpec error_; + + // Whether to include detailed breakdown of mismatches in the error message. + bool detailed_message_; + + // Callback to invoke on miscompare. + MiscompareCallback miscompare_callback_; + + // Number of element element mismatches encountered so far. + int64 num_mismatches_ = 0; + + // Number of elements with a nan mismatch. 
+ int64 num_nan_mismatches_ = 0; + + // Number of elements which exceed the absolute/relative error bound. + int64 num_abs_mismatches_ = 0; + int64 num_rel_mismatches_ = 0; + + // A Literal recording which elements did not match in the expected and + // actual literals. mismatches_ contains PREDs and is of the same size as + // the comparison literals. + Literal mismatches_; + + // The number of mismatches to report in the output, sorted by relative error + // magnitude. + static constexpr int64 kTopRelativeErrorCount = 5; + + // The set of mismatches with the largest relative error. The size of this set + // is bounded by kTopRelativeErrorCount. + std::multiset top_rel_mismatches_; + + // Actual values are bucketed by absolute value. kAbsValueBucketBounds holds + // the bounds of these buckets. abs_value_buckets_ contains a pair for each + // bucket: the element count and failure count. + static constexpr std::array kAbsValueBucketBounds = { + 0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits::infinity()}; + std::vector> abs_value_buckets_; + + // Buckets for relative and absolute errors. The relative error buckets + // contain only those elements which exceed the *absolute* error bound, and + // vice versa. This makes it easy to see the effect of adjusting the relative + // (or absolute) error bound on the success of the comparison. + // kErrorBucketBounds are the lower bounds of the buckets in both vectors. The + // error buckets are a cumulative distribution so an error value may appear in + // more than one bucket. For example an error value of 0.003 may appear in the + // buckets bounded below by 0.0001 and 0.001. 
+ static constexpr std::array kErrorBucketBounds = {0.0001, 0.001, + 0.01, 0.1, 1}; + std::vector abs_error_buckets_; + std::vector rel_error_buckets_; +}; + +template +constexpr std::array NearComparator::kAbsValueBucketBounds; +template +constexpr std::array NearComparator::kErrorBucketBounds; + +// Helper function for comparing two literals for nearness. Handles tuple-shapes +// via recursion. shape_index is the ShapeIndex of expected (or actual) +// currently being compared. +Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error, bool detailed_message, + const MiscompareCallback& miscompare_callback, + const ShapeIndex& shape_index) { + TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); + + if (ShapeUtil::IsTuple(expected.shape())) { + Status return_status; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + const auto expected_element = LiteralSlice(expected, {i}); + const auto actual_element = LiteralSlice(actual, {i}); + ShapeIndex element_index = shape_index; + element_index.push_back(i); + Status res = + NearHelper(expected_element, actual_element, error, detailed_message, + miscompare_callback, element_index); + if (!res.ok()) { + string err_message = Printf("\nArray at shape index %s%s", + element_index.ToString().c_str(), + res.error_message().c_str()); + if (return_status.ok()) { + return_status = res; + } else { + return_status = AppendStatus(return_status, res.error_message()); + } + } + } + if (!return_status.ok() && shape_index.empty()) { + // Emit a top-level error message containing the top-level shape in case + // of mismatch. 
+ int64 total_elements = RecursiveElementCount(actual.shape()); + return_status = InvalidArgument( + "\nMismatches in shape %s (%lld elements):\n%s", + ShapeUtil::HumanString(actual.shape()).c_str(), total_elements, + return_status.error_message().c_str()); + } + return return_status; + } + + if (ShapeUtil::ElementIsFloating(expected.shape()) || + ShapeUtil::ElementIsComplex(expected.shape())) { + switch (expected.shape().element_type()) { + case BF16: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + case F16: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + case F32: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + case F64: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + case C64: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + default: + LOG(FATAL) << "Unsupported primitive type in near comparator: " + << PrimitiveType_Name(expected.shape().element_type()) + << ". Must be floating-point type."; + } + } + + // Non-floating point literal. + return literal_comparison::Equal(expected, actual); +} + +} // namespace + +Status EqualShapes(const Shape& expected, const Shape& actual) { + if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) { + return InvalidArgument("tupleness-mismatch! 
want: %s got %s", + ShapeUtil::HumanString(expected).c_str(), + ShapeUtil::HumanString(actual).c_str()); + } + if (ShapeUtil::IsTuple(expected)) { + if (ShapeUtil::TupleElementCount(expected) != + ShapeUtil::TupleElementCount(actual)) { + return InvalidArgument( + "want tuple element count: %lld got tuple element count: %lld", + ShapeUtil::TupleElementCount(expected), + ShapeUtil::TupleElementCount(actual)); + } + for (int i = 0; i < expected.tuple_shapes_size(); ++i) { + Status result = + EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + if (!result.ok()) { + return AppendStatus(result, StrCat("mismatch in tuple index", i)); + } + } + } else { + if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { + return InvalidArgument("want rank of %s got rank of %s", + ShapeUtil::HumanString(expected).c_str(), + ShapeUtil::HumanString(actual).c_str()); + } + if (expected.element_type() != actual.element_type()) { + return InvalidArgument( + "mismatch in primitive type %s vs %s", + PrimitiveType_Name(expected.element_type()).c_str(), + PrimitiveType_Name(actual.element_type()).c_str()); + } + if (expected.dimensions_size() != actual.dimensions_size()) { + return InvalidArgument("want dimensions_size %d got dimensions_size %d", + expected.dimensions_size(), + actual.dimensions_size()); + } + for (int i = 0; i < expected.dimensions_size(); ++i) { + if (expected.dimensions(i) != actual.dimensions(i)) { + return InvalidArgument( + "mismatch in dimension #%d expected: %s actual: %s", i, + ShapeUtil::HumanString(expected).c_str(), + ShapeUtil::HumanString(actual).c_str()); + } + } + } + return Status::OK(); +} + +Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { + VLOG(1) << "expected:"; + XLA_VLOG_LINES(1, expected.ToString()); + VLOG(1) << "actual:"; + XLA_VLOG_LINES(1, actual.ToString()); + + TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); + std::vector multi_index(expected.shape().dimensions_size(), 0); + Status 
result; + switch (expected.shape().element_type()) { + case PRED: + result = Equal(expected, actual, &multi_index, 0); + break; + case U8: + result = Equal(expected, actual, &multi_index, 0); + break; + case S32: + result = Equal(expected, actual, &multi_index, 0); + break; + case S64: + result = Equal(expected, actual, &multi_index, 0); + break; + case U32: + result = Equal(expected, actual, &multi_index, 0); + break; + case U64: + result = Equal(expected, actual, &multi_index, 0); + break; + case BF16: + result = Equal(expected, actual, &multi_index, 0); + break; + case F16: + result = Equal(expected, actual, &multi_index, 0); + break; + case F32: + result = Equal(expected, actual, &multi_index, 0); + break; + case F64: + result = Equal(expected, actual, &multi_index, 0); + break; + case C64: + result = Equal(expected, actual, &multi_index, 0); + break; + case TUPLE: { + for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + result.Update( + Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}))); + } + break; + } + default: + LOG(FATAL) + << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " + << PrimitiveType_Name(expected.shape().element_type()); + } + + if (result.ok()) { + return Status::OK(); + } + + return AppendStatus(result, + tensorflow::strings::Printf( + "\nat index: %s\nexpected: %s\nactual: %s", + Literal::MultiIndexAsString(multi_index).c_str(), + ToStringTruncated(expected).c_str(), + ToStringTruncated(actual).c_str())); +} + +Status Near(const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error, bool detailed_message, + const MiscompareCallback& miscompare_callback) { + return NearHelper(expected, actual, error, detailed_message, + miscompare_callback, + /*shape_index=*/{}); +} + +string ToStringTruncated(const LiteralSlice& literal) { + return RecursiveElementCount(literal.shape()) < 1000 + ? 
literal.ToString() + : "[TRUNCATED, Literal with more than 1000 values]"; +} + +} // namespace literal_comparison +} // namespace xla diff --git a/tensorflow/compiler/xla/literal_comparison.h b/tensorflow/compiler/xla/literal_comparison.h new file mode 100644 index 0000000000000000000000000000000000000000..00a13e361932e74a9a1e614d5c851d3851208852 --- /dev/null +++ b/tensorflow/compiler/xla/literal_comparison.h @@ -0,0 +1,72 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Library for comparing literals without taking a dependency on testing +// libraries. + +#ifndef TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ +#define TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ + +#include "tensorflow/compiler/xla/error_spec.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { +namespace literal_comparison { + +// Returns ok if the given shapes have the same rank, dimension sizes, and +// primitive types. +Status EqualShapes(const Shape& expected, const Shape& actual); + +// Returns ok if the expected and actual literals are (bitwise) equal for all +// elements in the literal. Also, asserts that the rank, dimensions sizes, and +// primitive type are equal. 
+Status Equal(const LiteralSlice& expected, const LiteralSlice& actual); + +using MiscompareCallback = + std::function; + +// Inspects whether the expected and actual literals are within the given error +// bound for all elements. Also, inspects whether the rank, dimensions sizes, +// and dimension bounds are equivalent. +// +// Tuples are matched recursively. +// +// When comparing tensors of non-floating-point type, this inspects for exact +// equality, ignoring the ErrorSpec. +// +// If the shape of the literals is neither a complex/floating-point tensor nor a +// tuple which contains a complex/floating-point tensor, Near() is equivalent to +// Equal(). We don't raise an error in this case, because we want to allow +// callers to call Near() even if they have no preconceptions about the shapes +// being compared. +// +// If detailed_message is true, then the error message in the assertion result +// will contain a more detailed breakdown of mismatches. +Status Near(const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error, bool detailed_message, + const MiscompareCallback& miscompare_callback); + +// Calling ToString on a literal with over 100 million elements takes around +// 3 minutes. The utility of printing a literal with >1000 elements is +// questionable, especially when writing the Literal proto to disk is orders +// of magnitude faster. 
+string ToStringTruncated(const LiteralSlice& literal); + +} // namespace literal_comparison +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index b3b5e34ba220c7e9bf1cefef4b27baa6faee2c20..6b295897004cebce003ddd3999aacf63915ffe5f 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -62,8 +62,49 @@ void ConvertEndianShort(char* bytes, int64 size) { } } +// Return a literal with all arrays of type FromNativeT converted to type +// ToNativeT in the given literal. +template +std::unique_ptr ConvertType(LiteralSlice literal) { + // First construct shape of the result. + Shape result_shape(literal.shape()); + ShapeUtil::ForEachMutableSubshape( + &result_shape, [](Shape* subshape, const ShapeIndex&) { + if (subshape->element_type() == + primitive_util::NativeToPrimitiveType()) { + subshape->set_element_type( + primitive_util::NativeToPrimitiveType()); + } + }); + auto result = MakeUnique(result_shape); + + // Then copy over the data from 'literal' converting FromNativeT values to + // ToNativeT values as necessary. 
+ ShapeUtil::ForEachSubshape( + literal.shape(), + [&](const Shape& subshape, const ShapeIndex& shape_index) { + if (ShapeUtil::IsArray(subshape)) { + if (subshape.element_type() == + primitive_util::NativeToPrimitiveType()) { + auto src = literal.data(shape_index); + auto dest = result->data(shape_index); + for (int64 i = 0; i < src.size(); ++i) { + dest[i] = static_cast(src[i]); + } + } else { + TF_CHECK_OK(result->CopyFrom(literal, + /*dest_shape_index=*/shape_index, + /*src_shape_index=*/shape_index)); + } + } + }); + return result; +} + } // namespace +LiteralBase::~LiteralBase() {} + std::ostream& operator<<(std::ostream& out, const Literal& literal) { out << literal.ToString(); return out; @@ -95,99 +136,89 @@ Literal::StrideConfig::StrideConfig( Literal::Literal(const Shape& shape) : Literal(shape, /*allocate_arrays=*/true) {} -Literal::Literal(const Shape& shape, bool allocate_arrays) - : shape_(shape), pieces_(shape), owns_buffers_(true) { - CHECK(LayoutUtil::HasLayout(shape)); - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - const Shape& subshape = piece.subshape(); - if (ShapeUtil::IsArray(subshape)) { - if (allocate_arrays) { - if (LayoutUtil::IsSparseArray(subshape)) { - // For sparse arrays, the buffer must be of the size of the maximum - // number of sparse elements possible. 
- const int64 max_sparse_elements = - LayoutUtil::MaxSparseElements(subshape.layout()); - piece.set_buffer( - new char[max_sparse_elements * ShapeUtil::ByteSizeOfPrimitiveType( - subshape.element_type())]); - piece.set_sparse_indices(new SparseIndexArray( - max_sparse_elements, ShapeUtil::Rank(subshape))); - } else { - piece.set_buffer(new char[piece.size_bytes()]); - } +void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); + + SetPiece(subshape, &child_piece, allocate_arrays); + + piece->emplace_back(std::move(child_piece)); + } + } else { + CHECK(ShapeUtil::IsArray(shape)); + if (allocate_arrays) { + if (LayoutUtil::IsSparseArray(shape)) { + // For sparse arrays, the buffer must be of the size of the maximum + // number of sparse elements possible. 
+ const int64 max_sparse_elements = + LayoutUtil::MaxSparseElements(shape.layout()); + piece->set_buffer( + new char[max_sparse_elements * + ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]); + piece->set_sparse_indices( + new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape))); } else { - piece.set_buffer(nullptr); + piece->set_buffer(new char[piece->size_bytes()]); } } } } -Literal::~Literal() { DeallocateBuffers(); } +Literal::Literal(const Shape& shape, bool allocate_arrays) + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(LayoutUtil::HasLayout(*shape_)); + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + CHECK(&root_piece_->subshape() == shape_.get()); -void Literal::DeallocateBuffers() { - if (owns_buffers_) { - for (auto& pair : pieces_) { - Piece& piece = pair.second; - if (piece.buffer() != nullptr) { - delete[] piece.buffer(); - delete piece.sparse_indices(); - } - } - } + SetPiece(*shape_, root_piece_, allocate_arrays); } -Literal::Literal(Literal&& other) { - shape_ = std::move(other.shape_); - pieces_ = std::move(other.pieces_); - // We need to iterate through the pieces to set the subshape pointer - // properly. It must refer to subshapes within shape_. 
- for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); +Literal::~Literal() { + if (root_piece_ != nullptr) { + DeallocateBuffers(); + delete root_piece_; } - owns_buffers_ = other.owns_buffers_; +} - other.shape_ = ShapeUtil::MakeNil(); - other.pieces_ = ShapeTree(other.shape_); - other.piece({}).set_subshape(&other.shape_); +void Literal::DeallocateBuffers() { + root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (piece->buffer() != nullptr) { + delete[] piece->buffer(); + delete piece->sparse_indices(); + } + }); } +Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); } + Literal& Literal::operator=(Literal&& other) { - DeallocateBuffers(); - shape_ = std::move(other.shape_); - pieces_ = std::move(other.pieces_); - // We need to iterate through the pieces to set the subshape pointer - // properly. It must refer to subshapes within shape_. 
- for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - } - owns_buffers_ = other.owns_buffers_; - - other.shape_ = ShapeUtil::MakeNil(); - other.pieces_ = ShapeTree(other.shape_); - other.piece({}).set_subshape(&other.shape_); + DCHECK(&other.root_piece_->subshape() == other.shape_.get()); + using std::swap; + swap(shape_, other.shape_); + swap(root_piece_, other.root_piece_); + DCHECK(&root_piece_->subshape() == shape_.get()); + return *this; } -std::unique_ptr Literal::CreateFromShape(const Shape& shape) { +std::unique_ptr LiteralBase::CreateFromShape(const Shape& shape) { auto literal = MakeUnique(shape); - for (auto& pair : literal->pieces_) { - Piece& piece = pair.second; - if (ShapeUtil::IsArray(piece.subshape())) { - memset(piece.untyped_data(), 0, piece.size_bytes()); - } - } + literal->root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (ShapeUtil::IsArray(piece->subshape())) { + memset(piece->untyped_data(), 0, piece->size_bytes()); + } + }); return literal; } -const SparseIndexArray* Literal::sparse_indices( +const SparseIndexArray* LiteralBase::sparse_indices( const ShapeIndex& shape_index) const { return piece(shape_index).sparse_indices(); } @@ -202,9 +233,19 @@ SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions)); } +/* static */ std::unique_ptr Literal::ConvertBF16ToF32( + const LiteralSlice& bf16_literal) { + return ConvertType(bf16_literal); +} + +/* static */ std::unique_ptr Literal::ConvertF32ToBF16( + const LiteralSlice& f32_literal) { + return ConvertType(f32_literal); +} + template Status Literal::CopySliceFromInternal( - const Literal& src_literal, tensorflow::gtl::ArraySlice src_base, + const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, 
tensorflow::gtl::ArraySlice copy_size) { TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); @@ -264,7 +305,7 @@ Status Literal::CopySliceFromInternal( return Status::OK(); } -Status Literal::CopyElementFrom(const Literal& src_literal, +Status Literal::CopyElementFrom(const LiteralSlice& src_literal, tensorflow::gtl::ArraySlice src_index, tensorflow::gtl::ArraySlice dest_index) { DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); @@ -293,22 +334,21 @@ std::vector Literal::DecomposeTuple() { elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}), /*allocate_arrays=*/false)); Literal& element = elements.back(); - for (auto& pair : element.pieces_) { - const ShapeIndex& index = pair.first; - Piece& dest_piece = pair.second; - ShapeIndex src_index = {i}; - for (int64 j : index) { - src_index.push_back(j); - } - Piece& src_piece = piece(src_index); - - // Move the respective buffer and sparse indices over to the element - // Literal. - dest_piece.set_buffer(src_piece.buffer()); - src_piece.set_buffer(nullptr); - dest_piece.set_sparse_indices(src_piece.sparse_indices()); - src_piece.set_sparse_indices(nullptr); - } + element.root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* dest_piece) { + ShapeIndex src_index = {i}; + for (int64 j : index) { + src_index.push_back(j); + } + Piece& src_piece = piece(src_index); + + // Move the respective buffer and sparse indices over to the element + // Literal. + dest_piece->set_buffer(src_piece.buffer()); + src_piece.set_buffer(nullptr); + dest_piece->set_sparse_indices(src_piece.sparse_indices()); + src_piece.set_sparse_indices(nullptr); + }); } // Set this literal to be nil-shaped. 
*this = Literal(); @@ -351,7 +391,9 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, } // namespace -Status Literal::Piece::CopyFrom(const Literal::Piece& src) { +Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { + CHECK(subshape_ != nullptr); + CHECK(src.subshape_ != nullptr); if (ShapeUtil::Equal(subshape(), src.subshape())) { // If the layouts are equal it's faster just to memcpy. memcpy(buffer(), src.buffer(), src.size_bytes()); @@ -388,7 +430,7 @@ Status Literal::Piece::CopyFrom(const Literal::Piece& src) { return Status::OK(); } -Status Literal::CopyFrom(const Literal& src_literal, +Status Literal::CopyFrom(const LiteralSlice& src_literal, const ShapeIndex& dest_shape_index, const ShapeIndex& src_shape_index) { const Shape& dest_subshape = @@ -401,36 +443,32 @@ Status Literal::CopyFrom(const Literal& src_literal, ShapeUtil::HumanString(dest_subshape).c_str(), ShapeUtil::HumanString(src_subshape).c_str()); } + return root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + if (!ShapeUtil::IsArray(piece->subshape())) { + return Status::OK(); + } - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } - - // Determine if this index is in the part of this literal that we want to - // copy over from src_literal. - bool in_subtree_to_copy = true; - for (int i = 0; i < dest_shape_index.size(); ++i) { - if (index[i] != dest_shape_index[i]) { - in_subtree_to_copy = false; - break; - } - } - if (!in_subtree_to_copy) { - continue; - } - - // Construct the index of the corresponding piece in the source literal. 
- ShapeIndex src_piece_index = src_shape_index; - for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { - src_piece_index.push_back(index[i]); - } - - TF_RETURN_IF_ERROR(piece.CopyFrom(src_literal.piece(src_piece_index))); - } - return Status::OK(); + // Determine if this index is in the part of this literal that we want + // to copy over from src_literal. + bool in_subtree_to_copy = true; + for (int i = 0; i < dest_shape_index.size(); ++i) { + if (index[i] != dest_shape_index[i]) { + in_subtree_to_copy = false; + break; + } + } + if (!in_subtree_to_copy) { + return Status::OK(); + } + // Construct the index of the corresponding piece in the source literal. + ShapeIndex src_piece_index = src_shape_index; + for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { + src_piece_index.push_back(index[i]); + } + TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index))); + return Status::OK(); + }); } Status Literal::MoveFrom(Literal&& src_literal, @@ -444,37 +482,32 @@ Status Literal::MoveFrom(Literal&& src_literal, ShapeUtil::HumanString(src_literal.shape()).c_str()); } - if (!(owns_buffers_ && src_literal.owns_buffers_)) { - return InvalidArgument( - "Source and destination literals must both own their buffers (ie, not " - "be views)"); - } + src_literal.root_piece_->ForEachSubpiece( + [&](const ShapeIndex& src_index, const Piece& src_piece) { + if (!ShapeUtil::IsArray(src_piece.subshape())) { + return; + } - for (auto& pair : src_literal.pieces_) { - const ShapeIndex& src_index = pair.first; - Piece& src_piece = pair.second; - if (!ShapeUtil::IsArray(src_piece.subshape())) { - continue; - } + ShapeIndex dest_index = dest_shape_index; + for (int64 i : src_index) { + dest_index.push_back(i); + } + Piece& dest_piece = piece(dest_index); + delete[] dest_piece.buffer(); + dest_piece.set_buffer(src_piece.buffer()); + delete dest_piece.sparse_indices(); + dest_piece.set_sparse_indices(src_piece.sparse_indices()); + }); - ShapeIndex dest_index 
= dest_shape_index; - for (int64 i : src_index) { - dest_index.push_back(i); - } - Piece& dest_piece = piece(dest_index); - delete[] dest_piece.buffer(); - dest_piece.set_buffer(src_piece.buffer()); - delete dest_piece.sparse_indices(); - dest_piece.set_sparse_indices(src_piece.sparse_indices()); - } + src_literal.shape_ = MakeUnique(ShapeUtil::MakeNil()); + delete src_literal.root_piece_; + src_literal.root_piece_ = new LiteralBase::Piece(); + src_literal.root_piece_->set_subshape(src_literal.shape_.get()); - src_literal.shape_ = ShapeUtil::MakeNil(); - src_literal.pieces_ = ShapeTree(src_literal.shape_); - src_literal.piece({}).set_subshape(&src_literal.shape_); return Status::OK(); } -Status Literal::CopySliceFrom(const Literal& src_literal, +Status Literal::CopySliceFrom(const LiteralSlice& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size) { @@ -743,7 +776,7 @@ void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { return CreateR2FromArray2D(*value); } -std::unique_ptr Literal::Relayout( +std::unique_ptr LiteralBase::Relayout( const Layout& new_layout, const ShapeIndex& shape_index) const { // Create new shape with 'new_layout' set at the given shape index. 
Shape new_shape = shape(); @@ -755,7 +788,7 @@ std::unique_ptr Literal::Relayout( return result; } -std::unique_ptr Literal::Relayout( +std::unique_ptr LiteralBase::Relayout( const Shape& shape_with_layout) const { CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) @@ -774,7 +807,48 @@ std::unique_ptr Literal::Relayout( return result; } -StatusOr> Literal::Reshape( +StatusOr> LiteralBase::Broadcast( + const Shape& result_shape, + tensorflow::gtl::ArraySlice dimensions) const { + if (!ShapeUtil::IsArray(shape())) { + return InvalidArgument("Broadcast only supports arrays."); + } + + for (int64 i = 0; i < dimensions.size(); i++) { + TF_RET_CHECK(shape().dimensions(i) == + result_shape.dimensions(dimensions[i])); + } + + std::unique_ptr result = MakeUnique(result_shape); + + // scratch_source_index is temporary storage space for the computed index into + // the input literal. We put it here to avoid allocating an std::vector in + // every iteration of ShapeUtil::ForEachIndex. 
+ std::vector scratch_source_index(shape().dimensions_size()); + + char* dest_data = static_cast(result->untyped_data()); + const char* source_data = static_cast(untyped_data()); + const int64 primitive_size = + ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); + + ShapeUtil::ForEachIndex( + result_shape, [&](tensorflow::gtl::ArraySlice output_index) { + for (int64 i = 0; i < dimensions.size(); ++i) { + scratch_source_index[i] = output_index[dimensions[i]]; + } + int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex( + result_shape, output_index); + int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex( + shape(), scratch_source_index); + memcpy(dest_data + primitive_size * dest_index, + source_data + primitive_size * source_index, primitive_size); + return true; + }); + + return std::move(result); +} + +StatusOr> LiteralBase::Reshape( tensorflow::gtl::ArraySlice dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Reshape does not support tuples."); @@ -788,7 +862,8 @@ StatusOr> Literal::Reshape( } // Because the layout is monotonic, we can simply reuse the same sequence of // values without changing their order. 
- output->shape_ = ShapeUtil::MakeShape(shape().element_type(), dimensions); + *output->mutable_shape_do_not_use() = + ShapeUtil::MakeShape(shape().element_type(), dimensions); int64 elements_before = ShapeUtil::ElementsIn(shape()); int64 elements_after = ShapeUtil::ElementsIn(output->shape()); @@ -802,7 +877,79 @@ StatusOr> Literal::Reshape( return std::move(output); } -std::unique_ptr Literal::Transpose( +/* static */ std::unique_ptr Literal::ReshapeSlice( + tensorflow::gtl::ArraySlice new_dimensions, + tensorflow::gtl::ArraySlice minor_to_major, + const LiteralSlice& literal) { + int64 new_num_elements = 1; + for (int64 i = 0; i < new_dimensions.size(); ++i) { + new_num_elements *= new_dimensions[i]; + } + CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); + CHECK_EQ(new_dimensions.size(), minor_to_major.size()); + + auto new_literal = MakeUnique( + ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); + + // Create a new shape with the given minor-to-major layout. This shape is used + // solely for converting linear address to multi-dimensional addresses when + // writing elements to the new literal. + Shape shape_with_layout = new_literal->shape(); + *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); + + // Copy data into new literal, element-by-element. 
+ for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { + std::vector from_multi_index = + IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i); + std::vector to_multi_index = + IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); + switch (literal.shape().element_type()) { + case PRED: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case U8: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case U32: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case S32: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case U64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case S64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case F32: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case F64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case C64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + default: + LOG(FATAL) << "Unhandled primitive element type: " + << PrimitiveType_Name(literal.shape().element_type()); + } + } + + return new_literal; +} + +std::unique_ptr LiteralBase::Transpose( tensorflow::gtl::ArraySlice permutation) const { CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) @@ -833,15 +980,31 @@ std::unique_ptr Literal::Transpose( for (auto index : LayoutUtil::MinorToMajor(shape())) { layout->add_minor_to_major(inverse_permutation[index]); } - std::unique_ptr new_literal = CreateFromShape(permuted_shape); - DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()), + auto new_literal = MakeUnique(permuted_shape); + DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), ShapeUtil::ByteSizeOf(shape())); - std::memcpy(new_literal->root_piece().buffer(), 
root_piece().buffer(), - root_piece().size_bytes()); + std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); return new_literal; } -std::unique_ptr Literal::Slice( +template +std::unique_ptr LiteralBase::SliceInternal( + const Shape& result_shape, + tensorflow::gtl::ArraySlice start_indices) const { + auto result_literal = MakeUnique(result_shape); + DimensionVector new_indices(ShapeUtil::Rank(result_shape)); + result_literal->EachCell( + [&](tensorflow::gtl::ArraySlice indices, NativeT /*value*/) { + for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { + new_indices[i] = indices[i] + start_indices[i]; + } + NativeT value = Get(new_indices); + result_literal->Set(indices, value); + }); + return result_literal; +} + +std::unique_ptr LiteralBase::Slice( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices) const { CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; @@ -858,71 +1021,37 @@ std::unique_ptr Literal::Slice( const auto result_shape = ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, LayoutUtil::MinorToMajor(shape())); - - auto result_literal = MakeUnique(result_shape); - - DimensionVector new_indices(ShapeUtil::Rank(result_shape)); switch (result_shape.element_type()) { case F32: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, float /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - float value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); + case BF16: + return SliceInternal(result_shape, start_indices); case C64: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, complex64 /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - complex64 value = Get(new_indices); - 
result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); case S32: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, int32 /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - int32 value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); case U32: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, uint32 /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - uint32 value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); default: LOG(FATAL) << "not yet implemented: " << PrimitiveType_Name(result_shape.element_type()); } } -Literal Literal::Clone() const { +Literal LiteralBase::Clone() const { Literal result(shape()); TF_CHECK_OK(result.CopyFrom(*this)); return result; } -std::unique_ptr Literal::CloneToUnique() const { +std::unique_ptr LiteralBase::CloneToUnique() const { auto result = MakeUnique(shape()); TF_CHECK_OK(result->CopyFrom(*this)); return result; } -string Literal::GetAsString(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const { +string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); CHECK(LayoutUtil::IsDenseArray(subshape)); switch (subshape.element_type()) { @@ -962,8 +1091,8 @@ string Literal::GetAsString(tensorflow::gtl::ArraySlice multi_index, } } -string Literal::GetSparseElementAsString(int64 sparse_element_number, - const ShapeIndex& shape_index) const { +string LiteralBase::GetSparseElementAsString( + int64 sparse_element_number, const ShapeIndex& 
shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); CHECK(LayoutUtil::IsSparseArray(subshape)); switch (subshape.element_type()) { @@ -1017,7 +1146,7 @@ string Literal::GetSparseElementAsString(int64 sparse_element_number, } } -StatusOr Literal::GetIntegralAsS64( +StatusOr LiteralBase::GetIntegralAsS64( tensorflow::gtl::ArraySlice multi_index) const { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { @@ -1040,6 +1169,27 @@ StatusOr Literal::GetIntegralAsS64( } } +size_t LiteralBase::Hash() const { + using tensorflow::Hash64; + using tensorflow::Hash64Combine; + + size_t hash_value = ShapeUtil::Hash(shape()); + + ShapeUtil::ForEachSubshape( + shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (ShapeUtil::IsTuple(subshape)) { + return; + } + + CHECK(LayoutUtil::IsDense(subshape.layout())); + hash_value = Hash64Combine( + hash_value, Hash64(static_cast(untyped_data(index)), + size_bytes(index))); + }); + + return hash_value; +} + Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, int64 value) { CHECK(LayoutUtil::IsDenseArray(shape())); @@ -1070,7 +1220,7 @@ Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, return Status::OK(); } -tensorflow::gtl::ArraySlice Literal::GetSparseIndex( +tensorflow::gtl::ArraySlice LiteralBase::GetSparseIndex( int64 sparse_element_number, const ShapeIndex& shape_index) const { const Piece& p = piece(shape_index); CHECK_GE(sparse_element_number, 0); @@ -1082,10 +1232,10 @@ void Literal::SortSparseElements(const ShapeIndex& shape_index) { piece(shape_index).SortSparseElements(); } -Literal Literal::GetFirstScalarLiteral() const { - CHECK(ShapeUtil::IsArray(shape_)); - CHECK_GT(ShapeUtil::ElementsIn(shape_), 0); - switch (shape_.element_type()) { +Literal LiteralBase::GetFirstScalarLiteral() const { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_GT(ShapeUtil::ElementsIn(shape()), 0); + switch (shape().element_type()) 
{ case PRED: return std::move(*Literal::CreateR0(GetFirstElement())); // 8 bit types. @@ -1121,11 +1271,11 @@ Literal Literal::GetFirstScalarLiteral() const { case U64: return std::move(*Literal::CreateR0(GetFirstElement())); default: - LOG(FATAL) << "Unhandled primitive type " << shape_.element_type(); + LOG(FATAL) << "Unhandled primitive type " << shape().element_type(); } } -void Literal::Piece::SortSparseElements() { +void LiteralBase::Piece::SortSparseElements() { switch (subshape().element_type()) { case PRED: SortSparseElementsInternal(); @@ -1176,7 +1326,7 @@ void Literal::Piece::SortSparseElements() { } template -void Literal::Piece::SortSparseElementsInternal() { +void LiteralBase::Piece::SortSparseElementsInternal() { CHECK(LayoutUtil::IsSparseArray(subshape())); int64 num_elements = sparse_indices()->index_count(); auto values = data(); @@ -1187,9 +1337,11 @@ void Literal::Piece::SortSparseElementsInternal() { namespace { -void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index, +void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, bool print_layout, std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + CHECK(LayoutUtil::HasLayout(literal.shape())); + CHECK(LayoutUtil::HasLayout(subshape)); auto shape_to_string = [print_layout](const Shape& shape) { if (print_layout) { @@ -1348,13 +1500,14 @@ void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index, } // namespace -int64 Literal::sparse_element_count() const { +int64 LiteralBase::sparse_element_count() const { CHECK(LayoutUtil::IsSparseArray(shape())); return sparse_indices()->index_count(); } -string Literal::ToString(bool print_layout) const { +string LiteralBase::ToString(bool print_layout) const { std::vector pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); ToStringHelper(*this, {}, print_layout, &pieces); return tensorflow::str_util::Join(pieces, ""); } @@ -1362,7 +1515,7 @@ string 
Literal::ToString(bool print_layout) const { /* static */ std::unique_ptr Literal::MakeTuple( tensorflow::gtl::ArraySlice elements) { std::vector element_shapes; - for (const Literal* element : elements) { + for (const auto* element : elements) { element_shapes.push_back(element->shape()); } auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); @@ -1372,6 +1525,19 @@ string Literal::ToString(bool print_layout) const { return literal; } +/* static */ std::unique_ptr Literal::MakeTupleFromSlices( + tensorflow::gtl::ArraySlice elements) { + std::vector element_shapes; + for (const auto& element : elements) { + element_shapes.push_back(element.shape()); + } + auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + for (int i = 0; i < elements.size(); ++i) { + TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i})); + } + return literal; +} + /* static */ std::unique_ptr Literal::MakeTupleOwned( std::vector> elements) { std::vector element_shapes; @@ -1387,7 +1553,7 @@ string Literal::ToString(bool print_layout) const { return literal; } -void Literal::EachCellAsString( +void LiteralBase::EachCellAsString( const std::function indices, const string& value)>& per_cell) const { if (ShapeUtil::HasZeroElements(shape())) { @@ -1403,7 +1569,7 @@ void Literal::EachCellAsString( namespace { template std::unique_ptr ConvertBetweenNativeTypesWithConverter( - const Literal& src_literal, const ConverterType& converter) { + const LiteralBase& src_literal, const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); auto result_literal = MakeUnique(ShapeUtil::ChangeElementType( src_literal.shape(), @@ -1419,7 +1585,8 @@ std::unique_ptr ConvertBetweenNativeTypesWithConverter( } template -std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { +std::unique_ptr ConvertBetweenNativeTypes( + const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return static_cast(src); }; return 
ConvertBetweenNativeTypesWithConverter( src_literal, converter); @@ -1428,7 +1595,7 @@ std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { template typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), std::unique_ptr>::type -BitcastBetweenNativeTypes(const Literal& src_literal) { +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return tensorflow::bit_cast(src); }; @@ -1443,12 +1610,12 @@ BitcastBetweenNativeTypes(const Literal& src_literal) { template typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), std::unique_ptr>::type -BitcastBetweenNativeTypes(const Literal& src_literal) { +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { LOG(FATAL) << "Invalid bitcast between types of different sizes."; } template -std::unique_ptr ConvertToC64(const Literal& src_literal) { +std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); auto result_literal = MakeUnique( ShapeUtil::ChangeElementType(src_literal.shape(), C64)); @@ -1466,7 +1633,7 @@ std::unique_ptr ConvertToC64(const Literal& src_literal) { } template -std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal, +std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); if (bitcast) { @@ -1486,7 +1653,7 @@ std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal, template StatusOr> ConvertIfDestTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type, + const LiteralBase& src_literal, PrimitiveType primitive_dest_type, bool bitcast) { switch (primitive_dest_type) { #define CONVERT_IF_TYPES_MATCH(type) \ @@ -1521,7 +1688,8 @@ StatusOr> ConvertIfDestTypeMatches( } StatusOr> ConvertSwitch( - const Literal& literal, PrimitiveType primitive_dest_type, bool bitcast) { + const LiteralBase& literal, PrimitiveType primitive_dest_type, + 
bool bitcast) { TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); if (literal.shape().element_type() == primitive_dest_type) { return literal.CloneToUnique(); @@ -1555,12 +1723,12 @@ StatusOr> ConvertSwitch( } // namespace -StatusOr> Literal::Convert( +StatusOr> LiteralBase::Convert( PrimitiveType primitive_dest_type) const { return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); } -StatusOr> Literal::BitcastConvert( +StatusOr> LiteralBase::BitcastConvert( PrimitiveType primitive_dest_type) const { if (primitive_util::BitWidth(shape().element_type()) != primitive_util::BitWidth(primitive_dest_type)) { @@ -1575,7 +1743,7 @@ StatusOr> Literal::BitcastConvert( return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); } -StatusOr> Literal::ConvertToShape( +StatusOr> LiteralBase::ConvertToShape( const Shape& dest_shape, bool round_f32_to_bf16) const { if (!ShapeUtil::IsTuple(dest_shape)) { if (round_f32_to_bf16 && shape().element_type() == F32 && @@ -1590,7 +1758,7 @@ StatusOr> Literal::ConvertToShape( } std::vector elements; for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { - auto element = LiteralView::Create(*this, {i}); + auto element = LiteralSlice(*this, {i}); TF_ASSIGN_OR_RETURN( auto new_element, element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); @@ -1602,8 +1770,8 @@ StatusOr> Literal::ConvertToShape( } template -bool Literal::Piece::EqualElementsInternal( - const Literal::Piece& other, std::vector* multi_index) const { +bool LiteralBase::Piece::EqualElementsInternal( + const LiteralBase::Piece& other, std::vector* multi_index) const { if (multi_index->size() == ShapeUtil::Rank(subshape())) { return (Get(*multi_index) == other.Get(*multi_index)); } @@ -1617,7 +1785,7 @@ bool Literal::Piece::EqualElementsInternal( return true; } -bool Literal::Piece::EqualElements(const Literal::Piece& other) const { +bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { 
DCHECK(ShapeUtil::Compatible(subshape(), other.subshape())); std::vector multi_index; @@ -1645,28 +1813,28 @@ bool Literal::Piece::EqualElements(const Literal::Piece& other) const { case C64: return EqualElementsInternal(other, &multi_index); default: - LOG(FATAL) << "Unimplemented: Literal::Piece::EqualElements for type " + LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type " << PrimitiveType_Name(subshape().element_type()); } } -bool Literal::operator==(const Literal& other) const { +bool LiteralBase::operator==(const LiteralBase& other) const { if (!ShapeUtil::Compatible(shape(), other.shape())) { return false; } - for (const auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - const Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } - const Piece& other_piece = other.piece(index); - if (!piece.EqualElements(other_piece)) { - return false; - } - } - return true; + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + const Piece& other_piece = other.piece(index); + if (!piece.EqualElements(other_piece)) { + return false; + } + return true; + }); } namespace { @@ -1684,11 +1852,11 @@ static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice data, } // namespace -bool Literal::IsAll(int8 value) const { - for (const auto& pair : pieces_) { - const Piece& piece = pair.second; +bool LiteralBase::IsAll(int8 value) const { + return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index, + const Piece& piece) { if (!ShapeUtil::IsArray(piece.subshape())) { - continue; + return true; } auto piece_is_all = [&]() { @@ -1741,41 +1909,41 @@ bool Literal::IsAll(int8 value) const { if (!piece_is_all()) { return false; } - } - return true; + return true; + }); } -bool Literal::IsAllFloat(float value) const { - for (const auto& pair : pieces_) { - const Piece& piece = pair.second; 
- if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } +bool LiteralBase::IsAllFloat(float value) const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } - auto piece_is_all = [&]() { - switch (shape().element_type()) { - case F32: - return AllElementsEqualValue(piece.data(), value); - case F64: - return AllElementsEqualValue(piece.data(), value); - case F16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - case BF16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - default: + auto piece_is_all = [&]() { + switch (shape().element_type()) { + case F32: + return AllElementsEqualValue(piece.data(), value); + case F64: + return AllElementsEqualValue(piece.data(), value); + case F16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case BF16: + return AllElementsEqualValue( + piece.data(), static_cast(value)); + default: + return false; + } + }; + if (!piece_is_all()) { return false; - } - }; - if (!piece_is_all()) { - return false; - } - } - return true; + } + return true; + }); } -bool Literal::IsAllComplex(complex64 value) const { +bool LiteralBase::IsAllComplex(complex64 value) const { switch (shape().element_type()) { case C64: return AllElementsEqualValue(root_piece().data(), @@ -1785,93 +1953,93 @@ bool Literal::IsAllComplex(complex64 value) const { } } -bool Literal::IsAllFirst() const { - for (const auto& pair : pieces_) { - const Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } - - // Empty shapes are not all the first element since there is no first - // element. 
- if (ShapeUtil::HasZeroElements(piece.subshape())) { - return false; - } - auto piece_is_all = [&]() { - switch (piece.subshape().element_type()) { - case PRED: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 8 bit types - case S8: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U8: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 16 bit types - case BF16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case F16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 32 bit types - case F32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 64 bit types - case C64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case F64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); +bool LiteralBase::IsAllFirst() const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; } - default: + + // Empty shapes are not all the first element since there is no first + // element. 
+ if (ShapeUtil::HasZeroElements(piece.subshape())) { return false; - } - }; + } + auto piece_is_all = [&]() { + switch (piece.subshape().element_type()) { + case PRED: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 8 bit types + case S8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 16 bit types + case BF16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 32 bit types + case F32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 64 bit types + case C64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + default: + return false; + } + }; - if (!piece_is_all()) { - return false; - } - } - return true; + if (!piece_is_all()) { + return false; + } + return true; + }); } -bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { +bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice indices) const { CHECK(ShapeUtil::IsArray(shape())); switch (shape().element_type()) { case U8: @@ -1913,7 +2081,7 @@ void CopyToRepeatedField(RepeatedFieldT* dest, } // namespace 
-void Literal::Piece::WriteToProto(LiteralProto* proto) const { +void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { *proto->mutable_shape() = subshape(); switch (subshape().element_type()) { case PRED: @@ -1969,12 +2137,12 @@ void Literal::Piece::WriteToProto(LiteralProto* proto) const { } } -const void* Literal::Piece::untyped_data() const { +const void* LiteralBase::Piece::untyped_data() const { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); return buffer(); } -void* Literal::Piece::untyped_data() { +void* LiteralBase::Piece::untyped_data() { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); return buffer(); } @@ -1995,7 +2163,7 @@ Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, } // namespace -Status Literal::Piece::CopyFromProto(const LiteralProto& proto) { +Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { // These conditions should have been checked in Literal::CreateFromProto. 
TF_RET_CHECK(proto.has_shape()); TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); @@ -2062,21 +2230,19 @@ Status Literal::Piece::CopyFromProto(const LiteralProto& proto) { return Status::OK(); } -LiteralProto Literal::ToProto() const { +LiteralProto LiteralBase::ToProto() const { LiteralProto proto; - for (const auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - const Piece& piece = pair.second; - - LiteralProto* proto_piece = &proto; - for (int64 i : index) { - while (proto_piece->tuple_literals_size() <= i) { - proto_piece->add_tuple_literals(); - } - proto_piece = proto_piece->mutable_tuple_literals(i); - } - piece.WriteToProto(proto_piece); - } + root_piece().ForEachSubpiece( + [&](const ShapeIndex& index, const Piece& piece) { + LiteralProto* proto_piece = &proto; + for (int64 i : index) { + while (proto_piece->tuple_literals_size() <= i) { + proto_piece->add_tuple_literals(); + } + proto_piece = proto_piece->mutable_tuple_literals(i); + } + piece.WriteToProto(proto_piece); + }); if (LayoutUtil::IsSparseArray(shape())) { CopyToRepeatedField(proto.mutable_sparse_indices(), @@ -2098,33 +2264,40 @@ StatusOr> Literal::CreateFromProto( auto literal = MakeUnique(proto.shape()); - for (auto& pair : literal->pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - const LiteralProto* proto_element = &proto; - for (int64 i : index) { - TF_RET_CHECK(i < proto_element->tuple_literals_size()); - proto_element = &proto_element->tuple_literals(i); - } + TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + const LiteralProto* proto_element = &proto; + for (int64 i : index) { + CHECK(i < proto_element->tuple_literals_size()); + proto_element = &proto_element->tuple_literals(i); + } - if (ShapeUtil::IsTuple(piece.subshape())) { - if (proto_element->tuple_literals_size() != - ShapeUtil::TupleElementCount(piece.subshape())) { - return InvalidArgument( - "Expected 
%lld tuple elements in LiteralProto, has %d", - ShapeUtil::TupleElementCount(piece.subshape()), - proto_element->tuple_literals_size()); - } - continue; - } + if (ShapeUtil::IsTuple(piece->subshape())) { + if (proto_element->tuple_literals_size() != + ShapeUtil::TupleElementCount(piece->subshape())) { + return InvalidArgument( + "Expected %lld tuple elements in LiteralProto, has %d", + ShapeUtil::TupleElementCount(piece->subshape()), + proto_element->tuple_literals_size()); + } + return Status::OK(); + } + + CHECK(ShapeUtil::IsArray(piece->subshape())); + TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); + + return Status::OK(); + })); - TF_RET_CHECK(ShapeUtil::IsArray(piece.subshape())); - TF_RETURN_IF_ERROR(piece.CopyFromProto(*proto_element)); - } return std::move(literal); } -const void* Literal::untyped_data(const ShapeIndex& shape_index) const { +/* static */ string Literal::MultiIndexAsString( + tensorflow::gtl::ArraySlice multi_index) { + return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}"); +} + +const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const { return piece(shape_index).untyped_data(); } @@ -2132,11 +2305,11 @@ void* Literal::untyped_data(const ShapeIndex& shape_index) { return piece(shape_index).untyped_data(); } -int64 Literal::size_bytes(const ShapeIndex& shape_index) const { +int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const { return piece(shape_index).size_bytes(); } -string Literal::GetR1U8AsString() const { +string LiteralBase::GetR1U8AsString() const { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(shape().element_type(), U8); @@ -2144,72 +2317,55 @@ string Literal::GetR1U8AsString() const { ShapeUtil::ElementsIn(shape())); } -/* static */ const LiteralView LiteralView::Create( - const Literal& literal, const ShapeIndex& view_root) { - return LiteralView(literal, view_root); -} +void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, 
Piece* piece) { + CHECK(ShapeUtil::IsTuple(shape)); + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); -size_t Literal::Hash() const { - using tensorflow::Hash64; - using tensorflow::Hash64Combine; + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); - size_t hash_value = ShapeUtil::Hash(shape()); - - ShapeUtil::ForEachSubshape( - shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (ShapeUtil::IsTuple(subshape)) { - return; - } - - CHECK(LayoutUtil::IsDense(subshape.layout())); - hash_value = Hash64Combine( - hash_value, Hash64(static_cast(untyped_data(index)), - size_bytes(index))); - }); - - return hash_value; -} - -LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) { - shape_ = ShapeUtil::GetSubshape(literal.shape(), view_root); - pieces_ = ShapeTree(shape_); - owns_buffers_ = false; - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - - ShapeIndex src_index = view_root; - for (int64 i : index) { - src_index.push_back(i); + if (ShapeUtil::IsTuple(subshape)) { + BuildPieceSubtree(subshape, &child_piece); } - const Piece& src_piece = literal.piece(src_index); - piece.set_buffer(src_piece.buffer()); - piece.set_sparse_indices(src_piece.sparse_indices()); - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); + + piece->emplace_back(std::move(child_piece)); } } -LiteralView::~LiteralView() {} +LiteralSlice::LiteralSlice(const LiteralBase& literal) + : LiteralBase(), root_piece_(&literal.root_piece()) {} -LiteralView::LiteralView(const LiteralView& other) { CopyFrom(other); } +LiteralSlice::LiteralSlice(const LiteralBase& literal, + const ShapeIndex& view_root) + : LiteralBase(), root_piece_(&literal.piece(view_root)) {} -LiteralView& LiteralView::operator=(const LiteralView& other) { - CopyFrom(other); - return *this; +BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const 
Shape& shape) + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(ShapeUtil::IsArray(*shape_)); + CHECK_NE(src_buf_ptr, nullptr); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = Piece(); + root_piece_.set_buffer(const_cast(src_buf_ptr)); + root_piece_.set_subshape(shape_.get()); } -void LiteralView::CopyFrom(const LiteralView& other) { - // We can't use the default copy-constructor/copy-assignment because - // Piece::subshape_ points to subshapes within the Shape of the owning - // Literal/LiteralView. - shape_ = other.shape(); - pieces_ = other.pieces_; - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); +BorrowingLiteral::BorrowingLiteral( + tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(ShapeUtil::IsTuple(*shape_)); + CHECK(!ShapeUtil::IsNestedTuple(*shape_)); + CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); + root_piece_ = Piece(); + root_piece_.set_subshape(shape_.get()); + BuildPieceSubtree(*shape_, &root_piece_); + + for (int i = 0; i < src_buf_ptrs.size(); ++i) { + const auto& src_shape = shape_->tuple_shapes(i); + CHECK(ShapeUtil::IsArray(src_shape)); + root_piece_.child(i).set_buffer(const_cast(src_buf_ptrs[i])); } - owns_buffers_ = false; } } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index c6bd03bf21ac8dc88e96856cffe02c758e7b996d..8e4159e360e042beb31a75c432a3c7dfa7356007 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -34,7 +34,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/sparse_index_array.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -52,14 +51,509 @@ limitations under the License. namespace xla { +// Forward declare Literal and LiteralSlice class to be used by the creation +// methods in the base class. +class Literal; +class LiteralSlice; + +// Abstract base class for literals. +class LiteralBase { + public: + virtual ~LiteralBase() = 0; + + // Literals are equal if they have compatible shapes and the same data + // values. Layout is not compared. + bool operator==(const LiteralBase& other) const; + bool operator!=(const LiteralBase& other) const { return !(*this == other); } + + // Returns the shape of the literal. + const Shape& shape() const { return root_piece().subshape(); } + + // Serialize to proto. + LiteralProto ToProto() const; + + // Returns an ArraySlice of the array for this literal for the given NativeT + // (e.g., float). CHECKs if the subshape of the literal at the given + // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type + // to native type. + template + tensorflow::gtl::ArraySlice data( + const ShapeIndex& shape_index = {}) const; + + // Returns a const pointer to the sparse index array. Returns nullptr if the + // literal is not a sparse array. + const SparseIndexArray* sparse_indices( + const ShapeIndex& shape_index = {}) const; + + // Returns a const pointer to (or size of) the underlying buffer holding the + // array at the given shape index. CHECKs if the subshape of the literal at + // the given ShapeIndex is not array. 
+ const void* untyped_data(const ShapeIndex& shape_index = {}) const; + int64 size_bytes(const ShapeIndex& shape_index = {}) const; + + // Returns this literal's data as a string. This literal must be a rank-1 U8 + // array. + string GetR1U8AsString() const; + + // Returns a string representation of the literal value. + // Warning: this function can take minutes for multi-million element Literals. + string ToString(bool print_layout = false) const; + + // Gets an element in the literal at the given index. The multi_index is + // CHECKed against the dimension sizes. + template + NativeT Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const; + // Overloads of Get for array literals. CHECKs if the literal is not + // array-shaped and dense. + template + NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; + + // Returns the element value at index (0, ..., 0), however many zeroes are + // required for that index. + template + NativeT GetFirstElement() const; + + // As Get(), but determines the correct type and converts the value + // into text. + string GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index = {}) const; + // As GetSparseElement(), but determines the correct type and converts the + // value into text. + string GetSparseElementAsString(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + // As Get(), but determines the correct type and converts the value into + // int64. This literal must be an array. + StatusOr GetIntegralAsS64( + tensorflow::gtl::ArraySlice multi_index) const; + + // Returns the multi-index of the element in a sparse literal at the given + // sparse element number. The sparse element number is the position with in + // the sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. 
+ tensorflow::gtl::ArraySlice GetSparseIndex( + int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; + + // Returns the value of the element in a sparse literal at the given sparse + // element number. The sparse element number is the position with in the + // sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. + template + NativeT GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + + // Invokes the "per cell" callback for each element in the provided + // literal with the element's indices and a string representation of + // the element's value. + // + // This function is useful if you want a polymorphic representation + // of the tensor's elements (turning it to a string for something + // like representation in a protobuf). + // + // This literal must have a dense layout. + void EachCellAsString( + const std::function indices, + const string& value)>& per_cell) const; + template + void EachCell(std::function indices, + NativeT value)> + per_cell) const; + + // Returns whether every element in this literal is equal to value. + // + // value is an int8 because we expect this to be called with small + // compile-time constants (0, -1, etc.) and so that whatever value you pass + // can be represented exactly by floating-point types as small as 16 bits. + // + // If value doesn't fit in this literal's type, returns false. Values of 1/0 + // are considered equal to true/false; other values are not considered equal + // to true. Also if this literal is not array-shaped false is returned. + bool IsAll(int8 value) const; + + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular floating-point number. + // + // If the literal is not a floating-point value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. 
The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for values that can be expressed precisely as a float, + // e.g. -0.5. Also if this literal is not array-shaped false is returned. + bool IsAllFloat(float value) const; + + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular complex number. + // + // If the literal is not a complex value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for complex values that can be expressed precisely as + // float pairs e.g. (-0.5, 1.0). + // + // This literal must have a dense layout. + bool IsAllComplex(complex64 value) const; + + // Literal consists entirely of the first element of the literal. + bool IsAllFirst() const; + + // Returns whether this literal is zero at the specified index. This literal + // must be an array with a dense layout. + bool IsZero(tensorflow::gtl::ArraySlice indices) const; + + // Returns the count of the elements in the array at the given shape index in + // this literal. + int64 element_count(const ShapeIndex& index = {}) const { + return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); + } + + // Returns the count of the elements in the sparse array at the given shape + // index in this literal, which will be no larger than + // LayoutUtil::MaxSparseElements(GetSubshape(shape(), index).layout()). + int64 sparse_element_count() const; + + // Compute a hash for this literal. This literal must not be a sparse tensor + // or a tuple containing a sparse tensor. + size_t Hash() const; + + // Converts this literal to the given shape. Returns an error if the + // conversion is not possible.
+ // + // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding + // instead of truncation; otherwise, truncation is used. + // + // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes + // the default behavior. + StatusOr> ConvertToShape( + const Shape& dest_shape, bool round_f32_to_bf16 = false) const; + + // Converts this literal to another primitive type using a bitcast + // conversion. The to and from primitive types must have the same bit + // width. Returns an error if the conversion is not possible. This literal + // must be array-shaped. + StatusOr> BitcastConvert( + PrimitiveType primitive_dest_type) const; + + // Converts this literal to another primitive type. Returns an error if the + // conversion is not possible. This literal must be array-shaped. + StatusOr> Convert( + PrimitiveType primitive_dest_type) const; + + // Returns a literal scalar representing the first element. + Literal GetFirstScalarLiteral() const; + + // Clones the underlying buffers into a new Literal, or new + // std::unique_ptr. + Literal Clone() const; + std::unique_ptr CloneToUnique() const; + + // TODO(b/67651157): The methods below which perform computation on Literals + // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with + // evaluator code which operates on Literals. + // + // Creates a new value that has the equivalent value as this + // literal, but conforms to new_layout; e.g. a literal matrix that was in {0, + // 1} minor-to-major dimension layout can be re-layed-out as {1, 0} + // minor-to-major dimension layout and the value in the cell at any given + // logical index (i0, i1) will be the same. + // + // For tuple shaped literals, shape_index should be used to select the inner + // array that the new layout applies to. 
+ // + // Note: this is useful when the client wants to ensure that a value placed in + // the XLA allocation tracker has a particular layout; for efficiency + // purposes or avoiding unimplemented operation/layout combinations. + std::unique_ptr Relayout(const Layout& new_layout, + const ShapeIndex& shape_index = {}) const; + + // An overload of Relayout which changes the layout of the entire shape rather + // than being limited to a single array within the shape. + std::unique_ptr Relayout(const Shape& shape_with_layout) const; + + // Creates a new literal by reshaping this literal to have the given + // dimensions. The total number of elements must not change; The + // implementation currently only supports monotonic dim0-major layouts. + // This literal must be an array. + StatusOr> Reshape( + tensorflow::gtl::ArraySlice dimensions) const; + + // Creates a new literal by broadcasting this literal with `dimensions` to + // yield a literal of shape `result_shape`. + StatusOr> Broadcast( + const Shape& result_shape, + tensorflow::gtl::ArraySlice dimensions) const; + + // Creates a new literal by reordering the dimensions of this literal. + // The given `permutation` must be a permutation of the dimension numbers + // in the original literal, and it specifies the order of the new dimensions + // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). + // For example, a transpose call on a literal of shape [3 x 8 x 4] and + // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. + // This literal must be an array. + std::unique_ptr Transpose( + tensorflow::gtl::ArraySlice permutation) const; + + // Creates a sub-array from this literal by extracting the indices + // [start_index, limit_index) of each dimension. The result literal has the + // same rank and layout as for the given literal. 
The number of indices in + // start_indices and limit_indices must be the rank of the literal, and the + // indices follow the order of the dimensions. + // This literal must be an array. + std::unique_ptr Slice( + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) const; + + // Creates a literal with a prepended dimension with bound "times"; e.g. a + // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this + // literal replicated four times. + // This literal must be an array. + template + std::unique_ptr Replicate(int64 times) const; + + // Creates a new Literal object with the shape specified as parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + // + // Note: It's an antipattern to use this method then immediately call + // Literal::Populate on the result (since that results in zero initialization, + // then reinitialization). Consider whether a call to MakeUnique(shape), + // followed by the call to Literal::Populate can be used instead. + static std::unique_ptr CreateFromShape(const Shape& shape); + + protected: + // A data structure representing a subshape at a particular ShapeIndex within + // the literal. For array-shaped ShapeIndexes, this data structure holds the + // pointer to the memory allocated for the array data. + class Piece { + public: + // Returns the buffer holding the array data for this piece as an array + // slice. This piece must be array-shaped. + template + tensorflow::gtl::ArraySlice data() const; + template + tensorflow::gtl::MutableArraySlice data(); + + // Returns the buffer holding the array data for this piece as a void*. This + // piece must be array-shaped. + void* untyped_data(); + const void* untyped_data() const; + + // Gets or sets an element in the array at the given index. The multi_index + // is CHECKed against the dimension sizes of the array.
This piece must be + // array-shaped. + template + NativeT Get(tensorflow::gtl::ArraySlice index) const; + template + void Set(tensorflow::gtl::ArraySlice index, NativeT value); + + // Gets/sets the buffer holding the array data. + char* buffer() const { return buffer_; } + void set_buffer(char* buffer) { buffer_ = buffer; } + + // The array of multi-indices that provide the locations of non-zero + // elements in a sparse array. Only used if + // LayoutUtil::IsSparseArray(shape()) is true. + SparseIndexArray* sparse_indices() const { return sparse_indices_; } + void set_sparse_indices(SparseIndexArray* sparse_indices) { + sparse_indices_ = sparse_indices; + } + + // Gets or sets the subshape of this piece. This reference points to a + // subshape within the shape in the containing Literal (Literal::shape_). + const Shape& subshape() const { return *subshape_; } + void set_subshape(const Shape* subshape) { subshape_ = subshape; } + + // Returns the size in bytes of the buffer holding the array data. + int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } + + // Returns the number of elements in this piece's array. + int64 element_count() const { + // If this is a sparse array, use the number of elements represented by + // the indices in the associated SparseIndexArray. + return LayoutUtil::IsSparseArray(subshape()) + ? sparse_indices()->index_count() + : ShapeUtil::ElementsIn(subshape()); + } + + // Returns the child piece at 'index' of this piece. + Piece& child(int64 index) { return children_[index]; } + + // Adds a child piece to this piece's children. + void emplace_back(Piece child_piece) { + children_.emplace_back(std::move(child_piece)); + } + + // Returns the number of child pieces of this piece. + int64 children_size() { return children_.size(); } + + // Visitor function that recursively traverses the piece and calls the + // given function at each child piece.
The function has the type: + // void (const ShapeIndex& index, const Piece& piece) + template + void ForEachSubpiece(const Fn& func) const { + ShapeIndex index; + return ForEachHelper( + [&func](const ShapeIndex& index, const Piece& piece) { + func(index, piece); + return Status::OK(); + }, + *this, &index) + .IgnoreError(); + } + // Same as above, but the function has the type: + // Status (const ShapeIndex& index, const Piece& piece) + // The first non-OK return value is returned by the function. + template + Status ForEachSubpieceWithStatus(const Fn& func) const { + ShapeIndex index; + return ForEachHelper(func, *this, &index); + } + // Same as above, but the function has the type: + // Bool (const ShapeIndex& index, const Piece& piece) + // The first non-true return value is returned by the function. + template + bool ForEachSubpieceWithBool(const Fn& func) const { + ShapeIndex index; + return ForEachHelperBool(func, *this, &index); + } + // Same as above, but the function has the type: + // Void (const ShapeIndex& index, Piece& piece) + template + void ForEachMutableSubpiece(const Fn& func) { + ShapeIndex index; + return ForEachMutableHelper( + [&func](const ShapeIndex& index, Piece* piece) { + func(index, piece); + return Status::OK(); + }, + const_cast(this), &index) + .IgnoreError(); + } + // Same as above, but the function has the type: + // Status (const ShapeIndex& index, Piece& piece) + // The first non-OK return value is returned by the function. + template + Status ForEachMutableSubpieceWithStatus(const Fn& func) { + ShapeIndex index; + return ForEachMutableHelper( + func, const_cast(this), &index); + } + + // Returns true if this piece and 'other' contain the same data. This piece + // and 'other' must be array-shaped and compatible. + bool EqualElements(const Piece& other) const; + + // Writes the shape and data (if array-shaped) into the given proto. 
+ void WriteToProto(LiteralProto* proto) const; + + // Copy the data from 'src' into this piece's buffer. Shapes of this piece + // and src must be compatible. + Status CopyFrom(const Piece& src); + + // Copies the data from the given proto into this piece. The shape of this + // piece must be equal (not just compatible) to the shape of the proto. + Status CopyFromProto(const LiteralProto& proto); + + // Sorts the elements in a sparse array. + void SortSparseElements(); + + private: + // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'. + // The first non-OK (or non-true) value is returned by the function. + // The callable 'func' has the same signature as described above in + // ForEachSubpiece*. + template + Status ForEachHelper(const Fn& func, const Piece& piece, + ShapeIndex* index) const { + TF_RETURN_IF_ERROR(func(*index, piece)); + for (int64 i = 0; i < piece.children_.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index)); + index->pop_back(); + } + return Status::OK(); + } + template + bool ForEachHelperBool(const Fn& func, const Piece& piece, + ShapeIndex* index) const { + if (!func(*index, piece)) { + return false; + } + for (int64 i = 0; i < piece.children_.size(); ++i) { + index->push_back(i); + if (!ForEachHelperBool(func, piece.children_[i], index)) { + return false; + } + index->pop_back(); + } + return true; + } + template + Status ForEachMutableHelper(const Fn& func, Piece* piece, + ShapeIndex* index) { + TF_RETURN_IF_ERROR(func(*index, piece)); + for (int64 i = 0; i < piece->children_.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR( + ForEachMutableHelper(func, &piece->children_[i], index)); + index->pop_back(); + } + return Status::OK(); + } + + // Recursive helper for EqualElements. 
+ template + bool EqualElementsInternal(const Piece& other, + std::vector* multi_index) const; + + // Helper for SortSparseElements that has the element type as a template + // parameter. + template + void SortSparseElementsInternal(); + + // For array-shaped pieces, this is the buffer holding the literal data. + char* buffer_ = nullptr; + + // For sparse arrays, this is the array of indices. + SparseIndexArray* sparse_indices_ = nullptr; + + // The shape of piece. This points into the shape of the containing Literal + // (Literal::shape_). + const Shape* subshape_ = nullptr; + + // Children pieces for tuple shaped pieces. + std::vector children_ = {}; + }; // class Piece + + const Piece& piece(const ShapeIndex& shape_index) const { + Piece* piece = &const_cast(root_piece()); + for (const auto i : shape_index) { + DCHECK_GE(i, 0); + DCHECK_LT(i, piece->children_size()); + piece = &piece->child(i); + } + return *piece; + } + + // Returns the piece at the root of the shape. + virtual const Piece& root_piece() const = 0; + + // LiteralSlice and Literal must access Pieces of other Literals. + friend class Literal; + friend class LiteralSlice; + friend class BorrowingLiteral; + + private: + template + std::unique_ptr SliceInternal( + const Shape& result_shape, + tensorflow::gtl::ArraySlice start_indices) const; +}; + // Class representing literal values in XLA. // -// TODO(b/67651157): The methods in this class should be reduced to a minimal -// set of methods which construct Literals and accessors methods. Other methods -// which perform computation on Literals (Reshape, Slice, etc) should be moved -// elsewhere, and perhaps combined with evaluator code which operates on -// Literals. -class Literal { +// The underlying buffer and shape is always owned by this class. 
+class Literal : public LiteralBase { public: Literal() : Literal(ShapeUtil::MakeNil()) {} @@ -80,46 +574,156 @@ class Literal { Literal(const Shape& shape, bool allocate_arrays); Literal& operator=(Literal&& other); - // Literals are equal if they have compatible shapes and the same data - // values. Layout is not compared. - bool operator==(const Literal& other) const; - bool operator!=(const Literal& other) const { return !(*this == other); } + // TODO(b/67651157): Remove this accessor. Literal users should not be able to + // mutate the shape as this can produce malformed Literals. + Shape* mutable_shape_do_not_use() { return shape_.get(); } - // Serialize to and from a proto. - static StatusOr> CreateFromProto( - const LiteralProto& proto); - LiteralProto ToProto() const; + // Returns a MutableArraySlice view of the array for this literal for the + // given NativeT (e.g., float). CHECKs if the subshape of the literal at the + // given ShapeIndex is not array. See primitive_util.h for the mapping from + // XLA type to native type. + template + tensorflow::gtl::MutableArraySlice data( + const ShapeIndex& shape_index = {}); + // Unhide const method from parent class. + using LiteralBase::data; + + // Returns a pointer to the sparse index array. Returns nullptr if the literal + // is not a sparse array. + SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + + // Returns a pointer to the underlying buffer holding the array at the given + // shape index. CHECKs if the subshape of the literal at the given ShapeIndex + // is not array. + void* untyped_data(const ShapeIndex& shape_index = {}); + // Unhide const method from parent class. + using LiteralBase::untyped_data; + + // Populates a literal with a sparse layout with the given indices and values. + // Each index in the indices array is CHECKed against the dimensions in the + // literal's shape. If sort is true, then the indices and values will be + // sorted. 
If sort is false, then the indices and values are assumed to + // already be in sorted order. See CreateSparse for an example of how data + // are populated. + template + void PopulateSparse(SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, + bool sort = true); + + // Copy values from 'src_literal' rooted at 'src_shape_index' into this + // literal rooted at 'dest_shape_index'. The subshape of this literal rooted + // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' + // rooted at 'src_shape_index', but need not be arrays. + Status CopyFrom(const LiteralSlice& src_literal, + const ShapeIndex& dest_shape_index = {}, + const ShapeIndex& src_shape_index = {}); + + // Similar to CopyFrom, but with move semantics. The subshape of this literal + // rooted at 'dest_shape_index' must be *equal* to the shape of 'src_literal' + // (layouts and shapes must match), but need not be arrays. The memory + // allocated in this literal for the subshape at dest_shape_index is + // deallocated, and the respective buffers are replaced with those in + // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). + Status MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index = {}); + + // Copies the values from src_literal, starting at src_base shape indexes, + // to this literal, starting at dest_base, where the copy size in each + // dimension is specified by copy_size. + // The src_literal and this literal must have the same primitive type, + // src_base+copy_size must fit the source literal dimensions, and + // dest_base+copy_size must fit the destination literal dimensions. + // Note: if either src_literal or this literal contains dimensions with zero + // elements, then copy_size must be 0 in these dimensions and the + // corresponding base indices must be 0. + // This literal and 'src_literal' must be arrays.
+ Status CopySliceFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); - // Return the shape of the literal. - const Shape& shape() const { return shape_; } + // Copies one element from src_literal[src_index] to (*this)[dest_index]. + Status CopyElementFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_index, + tensorflow::gtl::ArraySlice dest_index); + + // Sets an element in the literal at the given index. The multi_index is + // CHECKed against the dimension sizes. + template + void Set(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index, NativeT value); + // Overloads of Set for array literals. CHECKs if the literal is not + // array-shaped and dense. + template + void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); + + // Appends the given element to the literal. If the elements are not appended + // in sorted order, then SortSparseElements should be called before calling + // other methods. This literal must have a sparse layout. + template + void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, + NativeT value, const ShapeIndex& shape_index = {}); + + // Sorts the elements in a sparse array. + void SortSparseElements(const ShapeIndex& shape_index = {}); + + // As Set(), but truncates `value` to the literal element type before storing. + // This literal must be an array. + Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, + int64 value); + + // Populate this literal with the given values. Examples: + // + // // Populate with floats. + // Array2D float_values = ... + // literal.PopulateR2FromArray2D(values); + // + // // Populate with int32s. + // literal.PopulateR2({{1, 2}, {3, 4}}); + // + // The shape and element type of this literal must match given values. For + // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 + // array of S32. 
+ template + void PopulateR1(tensorflow::gtl::ArraySlice values); + void PopulateR1(const tensorflow::core::Bitmap& values); + template + void PopulateR2(std::initializer_list> values); + template + void PopulateFromArray(const Array& values); + template + void PopulateR2FromArray2D(const Array2D& values); + template + void PopulateR3FromArray3D(const Array3D& values); + template + void PopulateR4FromArray4D(const Array4D& values); + + // Populates literal values by calling the generator function for every cell + // in this literal object. + // + // generator must be a callable of the type + // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. + // + // This literal must have a dense layout. + template + Status Populate(const FnType& generator); - // TODO(b/67651157): Remove this accessor. Literal users should not be able to - // mutate the shape as this can produce malformed Literals. - Shape* mutable_shape_do_not_use() { return &shape_; } + // A parallel version of Populate(). This can be used if the generator is + // thread-safe and the values for the shape's different elements are + // independent. + template + Status PopulateParallel(const FnType& generator); - // Returns a (Mutable)ArraySlice view of the array for this literal for the - // given NativeT (e.g., float). CHECKs if the subshape of the literal at the - // given ShapeIndex is not array. See primitive_util.h for the mapping from - // XLA type to native type. - template - tensorflow::gtl::ArraySlice data( - const ShapeIndex& shape_index = {}) const; + // Fills this literal with the given value. template - tensorflow::gtl::MutableArraySlice data( - const ShapeIndex& shape_index = {}); + void PopulateWithValue(NativeT value); - // Returns a pointer to the sparse index array. Returns nullptr if the literal - // is not a sparse array. 
- const SparseIndexArray* sparse_indices( - const ShapeIndex& shape_index = {}) const; - SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + // Factory methods below. + // - // Returns a pointer to (or size of) the underlying buffer holding the array - // at the given shape index. CHECKs if the subshape of the literal at the - // given ShapeIndex is not array. - const void* untyped_data(const ShapeIndex& shape_index = {}) const; - void* untyped_data(const ShapeIndex& shape_index = {}); - int64 size_bytes(const ShapeIndex& shape_index = {}) const; + // Serialize from a proto. + static StatusOr> CreateFromProto( + const LiteralProto& proto); // Creates a new literal of a given rank. To minimize ambiguity (for users // and the compiler) these CreateR[0-2] methods should explicitly specify the @@ -167,10 +771,6 @@ class Literal { values, const Layout& layout); - // Returns this literal's data as a string. This literal must be a rank-1 U8 - // array. - string GetR1U8AsString() const; - // Creates a literal with a sparse layout and the given indices and values. // The shape is initialized from the given dimensions. The minor dimension of // the indices array must equal the rank of the shape (i.e. size of the @@ -210,171 +810,16 @@ class Literal { tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, tensorflow::gtl::ArraySlice values, bool sort = true); - // Populates a literal with a sparse layout with the given indices and values. - // Each index in the indices array is CHECKed against the dimensions in the - // literal's shape. If sort is true, then the indices and values will be - // sorted. If sort is false, then the indices and values are assumed to - // already be in sorted order. See CreateSparse for an example of how data - // are populated. - template - void PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, - bool sort = true); - - // Creates a new Literal object with the shape specified as parameter. 
- // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr CreateFromShape(const Shape& shape); - - // Creates a new Literal object with its values havings the primitive_type - // type, and with dimensions defined by the dimensions parameter. - // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr CreateFromDimensions( - PrimitiveType primitive_type, - tensorflow::gtl::ArraySlice dimensions); - - // Copy values from 'src_literal' rooted at 'src_shape_index' into this - // literal rooted at 'dest_shape_index'. The subshape of this literal rooted - // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' - // rooted at 'src_shape_index', but need not be arrays. - Status CopyFrom(const Literal& src_literal, - const ShapeIndex& dest_shape_index = {}, - const ShapeIndex& src_shape_index = {}); - - // Similar to CopyFrom, but with move semantincs. The subshape of this literal - // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' - // (layouts and shapes must match), but need not be arrays. The memory - // allocated in this literal for the subshape at dest_shape_index is - // deallocated, and the respective buffers are replaced with those in - // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). - Status MoveFrom(Literal&& src_literal, - const ShapeIndex& dest_shape_index = {}); - - // Copies the values from src_literal, starting at src_base shape indexes, - // to this literal, starting at dest_base, where the copy size in each - // dimension is specified by copy_size. 
- // The src_literal and this literal must have the same primitive type, - // src_base+copy_size must fit the source literal dimensions, as well as - // dest_base+copy_size must fit the destination literal dimensions. - // Note: if either src_literal or this literal contains dimensions with zero - // element, then copy_size must be 0 in these dimensions while the - // corresponding base indices being 0. - // This literal and 'src_literal' must be arrays. - Status CopySliceFrom(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); - - // Copies one element from src_literal[src_index] to (*this)[dest_index]. - Status CopyElementFrom(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index); - - // Returns a vector containing the tuple elements of this Literal as separate - // Literals. This Literal must be tuple-shaped and can be a nested tuple. The - // elements are moved into the new Literals; no data is copied. Upon return - // this Literal is set to a nil shape (empty tuple) - std::vector DecomposeTuple(); - - // This operation is the inverse of DecomposeTuple. The given elements are - // moved into the tuple elements of a new tuple-shaped Literal which is - // returned. Upon return, each of the Literals in 'elements' is set to a nil - // shape (empty tuple). - static Literal MoveIntoTuple( - tensorflow::gtl::MutableArraySlice elements); - - // Creates a new value that has the equivalent value as this literal, but - // conforms to new_layout; e.g. a literal matrix that was in {0, 1} - // minor-to-major dimension layout can be re-laid-out as {1, 0} - // minor-to-major dimension layout and the value in the cell at any given - // logical index (i0, i1) will be the same. - // - // For tuple shaped literals, shape_index should be used to select the inner - // array that the new layout applies to. 
- // - // Note: this is useful when the client wants to ensure that a value placed in - // the XLA allocation tracker has a particular layout; for efficiency - // purposes or avoiding unimplemented operation/layout combinations. - std::unique_ptr Relayout(const Layout& new_layout, - const ShapeIndex& shape_index = {}) const; - - // An overload of Relayout which changes the layout of the entire shape rather - // than being limited to a single array within the shape. - std::unique_ptr Relayout(const Shape& shape_with_layout) const; - - // Creates a new literal by reshaping this literal to have the given - // dimensions. The total number of elements must not change; The - // implementation currently only supports monotonic dim0-major layouts. - // This literal must be an array. - StatusOr> Reshape( - tensorflow::gtl::ArraySlice dimensions) const; - - // Creates a new literal by reordering the dimensions of this literal. - // The given `permutation` must be a permutation of the dimension numbers - // in the original literal, and it specifies the order of the new dimensions - // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). - // For example, a transpose call on a literal of shape [3 x 8 x 4] and - // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. - // This literal must be an array. - std::unique_ptr Transpose( - tensorflow::gtl::ArraySlice permutation) const; - - // Creates a sub-array from this literal by extracting the indices - // [start_index, limit_index) of each dimension. The result literal has the - // same rank and layout as for the given literal. The number of indices in - // start_indices and limit_indices must be the rank of the literal, and the - // indices follow the order of the dimensions. - // This literal must be an array. 
- std::unique_ptr Slice( - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) const; - - // Creates a literal with a prepended dimension with bound "times"; e.g. a - // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this - // literal replicated four times. - // This literal must be an array. - template - std::unique_ptr Replicate(int64 times) const; - - // Converts this literal to another primitive type using - // static_cast<>. Returns an error if the conversion is not possible. This - // literal must be array-shaped. - StatusOr> Convert( - PrimitiveType primitive_dest_type) const; - - // Converts this literal to another primitive type using a bitcast - // conversion. The to and from primitive types must have the same bit - // width. Returns an error if the conversion is not possible. This literal - // must be array-shaped. - StatusOr> BitcastConvert( - PrimitiveType primitive_dest_type) const; - - // Converts this literal to the given shape. Returns an error is the - // conversion is not possible. - // - // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding - // instead of truncation; otherwise, truncation is used. - // - // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes - // the default behavior. - StatusOr> ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16 = false) const; - // Creates a scalar literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); - // Creates a scalar literal value one of the given primitive type. static Literal One(PrimitiveType primitive_type); - // Creates a scalar literal value containing the minimum value of the given // primitive type. For floating-point types, returns -inf. static Literal MinValue(PrimitiveType primitive_type); - // Creates a scalar literal value containing the maximum value of the given // primitive type. For floating-point types, returns inf. 
static Literal MaxValue(PrimitiveType primitive_type); - // Creates a literal of the given shape where each element is `value`. template static std::unique_ptr CreateFullWithDescendingLayout( @@ -423,84 +868,11 @@ class Literal { int64 projection); // Creates a literal that projects the (x, y) dimensions given in values into - // the z and p dimensions given. - template - static std::unique_ptr CreateR4Projected( - std::initializer_list> values, - int64 projection_p, int64 projection_z); - - // Clones this literal into a new Literal, or new std::unique_ptr. - Literal Clone() const; - std::unique_ptr CloneToUnique() const; - - // Gets or sets an element in the literal at the given index. The multi_index - // is CHECKed against the dimension sizes. - template - NativeT Get(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const; - template - void Set(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value); - - // Overloads of Get and Set for array literals. CHECKs if the literal is not - // array-shaped and dense. - template - NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; - template - void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); - - // Returns the multi-index of the element in a sparse literal at the given - // sparse element number. The sparse element number is the position with in - // the sparse array's list of (index, value) pairs, and is checked against the - // total number of (index, value) pairs in the sparse array. - tensorflow::gtl::ArraySlice GetSparseIndex( - int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; - - // Returns the value of the element in a sparse literal at the given sparse - // element number. The sparse element number is the position with in the - // sparse array's list of (index, value) pairs, and is checked against the - // total number of (index, value) pairs in the sparse array. 
- template - NativeT GetSparseElement(int64 sparse_element_number, - const ShapeIndex& shape_index = {}) const; - - // Appends the given element to the literal. If the elements are not appended - // in sorted order, then SortSparseElements should be called before calling - // other methods. This literal must have a sparse layout. - template - void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, - NativeT value, const ShapeIndex& shape_index = {}); - - // Sorts the elements in a sparse array. - void SortSparseElements(const ShapeIndex& shape_index = {}); - - // Returns the element value at index (0, ..., 0), however many zeroes are - // required for that index. - template - NativeT GetFirstElement() const; - - // Returns a literal scalar representing the first element. - Literal GetFirstScalarLiteral() const; - - // As Get(), but determines the correct type and converts the value - // into text. - string GetAsString(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index = {}) const; - - // As GetSparseElement(), but determines the correct type and converts the - // value into text. - string GetSparseElementAsString(int64 sparse_element_number, - const ShapeIndex& shape_index = {}) const; - - // As Get(), but determines the correct type and converts the value into - // int64. This literal must be an array. - StatusOr GetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index) const; - - // As Set(), but truncates `value` to the literal element type before storing. - // This literal must be an array. - Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, - int64 value); + // the z and p dimensions given. + template + static std::unique_ptr CreateR4Projected( + std::initializer_list> values, + int64 projection_p, int64 projection_z); // Returns an identity matrix (rank 2) with the given row and column count. 
template @@ -511,6 +883,9 @@ class Literal { static std::unique_ptr MakeTuple( tensorflow::gtl::ArraySlice elements); + static std::unique_ptr MakeTupleFromSlices( + tensorflow::gtl::ArraySlice elements); + // As above, but intended to be invoked with move semantics; i.e. // // std::vector> elements = ...; @@ -542,135 +917,104 @@ class Literal { return MakeTupleOwned(std::move(v)); } - // Returns a string representation of the literal value. - // Warning: this function can take minutes for multi-million element Literals. - string ToString(bool print_layout = false) const; - - // Invokes the "per cell" callback for each element in the provided - // literal with the element's indices and a string representation of - // the element's value. - // - // This function is useful if you want a polymorphic representation - // of the tensor's elements (turning it to a string for something - // like representation in a protobuf). - // - // This literal must have a dense layout. - void EachCellAsString( - const std::function indices, - const string& value)>& per_cell) const; - template - void EachCell(std::function indices, - NativeT value)> - per_cell) const; - - // Populate this literal with the given values. Examples: - // - // // Populate with floats. - // Array2D float_values = ... - // literal.PopulateR2FromArray2D(values); - // - // // Populate with int32s. - // literal.PopulateR2({{1, 2}, {3, 4}}); - // - // The shape and element type of this literal must match given values. For - // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 - // array of S32. 
- template - void PopulateR1(tensorflow::gtl::ArraySlice values); - void PopulateR1(const tensorflow::core::Bitmap& values); - template - void PopulateR2(std::initializer_list> values); - template - void PopulateFromArray(const Array& values); - template - void PopulateR2FromArray2D(const Array2D& values); - template - void PopulateR3FromArray3D(const Array3D& values); - template - void PopulateR4FromArray4D(const Array4D& values); - - // Populates literal values by calling the generator function for every cell - // in this literal object. - // - // generator must be a callable of the type - // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. - // - // This literal must have a dense layout. - template - Status Populate(const FnType& generator); - - // A parallel version of Populate(). This can be used if the generator is - // thread-safe and the values for the shape's different elements are - // independent. - template - Status PopulateParallel(const FnType& generator); + // Returns a vector containing the tuple elements of this Literal as separate + // Literals. This Literal must be tuple-shaped and can be a nested tuple. The + // elements are moved into the new Literals; no data is copied. Upon return + // this Literal is set to a nil shape (empty tuple) + std::vector DecomposeTuple(); - // Fills this literal with the given value. - template - void PopulateWithValue(NativeT value); + // This operation is the inverse of DecomposeTuple. The given elements are + // moved into the tuple elements of a new tuple-shaped Literal which is + // returned. Upon return, each of the Literals in 'elements' is set to a nil + // shape (empty tuple). + static Literal MoveIntoTuple( + tensorflow::gtl::MutableArraySlice elements); - // Returns whether every element in this literal is equal to value. - // - // value is an int8 because we expect this to be called with small - // compile-time constants (0, -1, etc.) 
and so that whatever value you pass - // can be represented exactly by floating-point types as small as 16 bits. - // - // If value doesn't fit in this literal's type, returns false. Values of 1/0 - // are considered equal to true/false; other values are not considered equal - // to true. Also if this literal is not array-shaped false is returned. - bool IsAll(int8 value) const; + // Creates a new Literal object with its values havings the primitive_type + // type, and with dimensions defined by the dimensions parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + static std::unique_ptr CreateFromDimensions( + PrimitiveType primitive_type, + tensorflow::gtl::ArraySlice dimensions); - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular floating-point number. - // - // If the literal is not a floating-point value, this always returns false. - // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for values that can be expressed precisely as a float, - // e.g. -0.5. Also if this literal is not array-shaped false is returned. - bool IsAllFloat(float value) const; + // If the given literal's data type is bfloat16, converts it to a float + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. + static std::unique_ptr ConvertBF16ToF32( + const LiteralSlice& bf16_literal); + + // If the given literal's data type is float, converts it to a bfloat16 + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. 
+ static std::unique_ptr ConvertF32ToBF16( + const LiteralSlice& f32_literal); + + // Creates a literal with a new shape with the given new dimensions using the + // data in the given input literal. For reshaping purposes the (flat) data + // buffer of the input literal is assumed to have the given minor_to_major + // layout order. + static std::unique_ptr ReshapeSlice( + tensorflow::gtl::ArraySlice new_dimensions, + tensorflow::gtl::ArraySlice minor_to_major, + const LiteralSlice& literal); + + // Creates a literal with the supplied shape, and uses the provided value + // generator to populate the literal's values. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, + const std::function)>& generator); + + // Creates a literal with the supplied shape, and initializes the literal + // values using a normal distribution with given mean and stddev standard + // deviation, and using the engine as entropy generator. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, typename E, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, E* engine, T mean, T stddev); + + // Creates a literal with the supplied shape, and initializes the literal + // values using a normal distribution with given mean and stddev standard + // deviation. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, T mean, T stddev); - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular complex number. 
- // - // If the literal is not a complex value, this always returns false. - // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for complex values that can be expressed precisely as - // float pairs e.g. (-0.5, 1.0). // - // This literal must have a dense layout. - bool IsAllComplex(complex64 value) const; + // End of factory methods. - // Literal consists entirely of the first element of the literal. - bool IsAllFirst() const; + // Returns a multi-dimensional index as a string. For example: '{7, 8}' will + // be returned for a 2-dimensional index with dimension 0 index equal to 7, + // dimension 1 equal to 8. + static string MultiIndexAsString( + tensorflow::gtl::ArraySlice multi_index); - // Returns whether this literal is zero at the specified index. This literal - // must be an array with a dense layout. - bool IsZero(tensorflow::gtl::ArraySlice indices) const; + private: + // Recursively sets the subshapes and buffers of all subpieces rooted at + // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in + // the shape. + void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); - // Return the count of the elements in the array at the given shape index in - // this literal. - int64 element_count(const ShapeIndex& index = {}) const { - return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); + // Returns the piece at the given ShapeIndex. + Piece& piece(const ShapeIndex& shape_index) { + return const_cast(LiteralBase::piece(shape_index)); } - // Return the count of the elements in the sparse array at the given shape - // index in this literal, which will be no larger than - // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). - int64 sparse_element_count() const; - - // Compute a hash for this literal. 
This literal must not be a sparse tensor - // or a tuple containing a sparse tensor. - size_t Hash() const; + Piece& root_piece() const override { return *root_piece_; }; - protected: // Internal template helper for the Literal::CopySliceFrom(), matching its // arguments one by one. template - Status CopySliceFromInternal(const Literal& src_literal, + Status CopySliceFromInternal(const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size); @@ -698,162 +1042,71 @@ class Literal { int64 minor_loop_size = 1; }; - // A data structure representing a subshape at a particular ShapeIndex within - // the literal. For array-shaped ShapeIndexes, this data structure holds the - // pointer to the memory allocated for the array data. - class Piece { - public: - // Return the buffer holding the array data for this piece as an array - // slice. This piece must be array-shaped. - template - tensorflow::gtl::ArraySlice data() const; - template - tensorflow::gtl::MutableArraySlice data(); - - // Return the buffer holding the array data for this piece as a void*. This - // piece must be array-shaped. - void* untyped_data(); - const void* untyped_data() const; - - // Gets or sets an element in the array at the given index. The multi_index - // is CHECKed against the dimension sizes of the array. This piece must be - // array-shaped. - template - NativeT Get(tensorflow::gtl::ArraySlice index) const; - template - void Set(tensorflow::gtl::ArraySlice index, NativeT value); - - // Gets/sets the buffer holding the array data. - char* buffer() const { return buffer_; } - void set_buffer(char* buffer) { buffer_ = buffer; } - - // The array of multi-indices that provide the locations of non-zero - // elements in a sparse array. Only used if - // LayoutUtil::IsSparseArray(shape()) is true. 
- SparseIndexArray* sparse_indices() const { return sparse_indices_; } - void set_sparse_indices(SparseIndexArray* sparse_indices) { - sparse_indices_ = sparse_indices; - } - - // Gets or sets the subshape of this piece. This reference points to a - // subshape within the shape in the containing Literal (Literal::shape_). - const Shape& subshape() const { return *subshape_; } - void set_subshape(const Shape* subshape) { subshape_ = subshape; } - - // Returns the size in bytes of the buffer holding the array data. - int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } - - // Returns the number of elements in this piece's array. - int64 element_count() const { - // If this is a sparse array, use the number of elements represented by - // the indices in the associated SparseIndexArray. - return LayoutUtil::IsSparseArray(subshape()) - ? sparse_indices()->index_count() - : ShapeUtil::ElementsIn(subshape()); - } - - // Copy the data from 'src' into this piece's buffer. Shapes of this piece - // and src must be compatible. - Status CopyFrom(const Piece& src); - - // Returns true if this piece and 'other' contain the same data. This piece - // and 'other' must be array-shaped and compatible. - bool EqualElements(const Piece& other) const; - - // Writes the shape and data (if array-shaped) into the given proto. - void WriteToProto(LiteralProto* proto) const; - - // Copies the data from the given proto into this piece. The shape of this - // piece must be equal (not just compatible) to the shape of the proto. - Status CopyFromProto(const LiteralProto& proto); - - // Sorts the elements in a sparse array. - void SortSparseElements(); - - private: - // Recursive helper for EqualElements. - template - bool EqualElementsInternal(const Piece& other, - std::vector* multi_index) const; - - // Helper for SortSparseElements that has the element type as a template - // parameter. 
- template - void SortSparseElementsInternal(); - - // For array-shaped pieces, this is the buffer holding the literal data. - char* buffer_ = nullptr; - - // For sparse arrays, this is the array of indices. - SparseIndexArray* sparse_indices_ = nullptr; - - // The shape of piece. This points into the shape of the containing Literal - // (Literal::shape_). - const Shape* subshape_ = nullptr; - }; - - // Returns the piece at the given ShapeIndex. - Piece& piece(const ShapeIndex& shape_index) { - return *pieces_.mutable_element(shape_index); - } - const Piece& piece(const ShapeIndex& shape_index) const { - return pieces_.element(shape_index); - } - - // Returns the piece at the root of the shape (empty ShapeIndex). - Piece& root_piece() { return piece({}); } - const Piece& root_piece() const { return piece({}); } + // Literal class always owns the shape. The parent class borrows this shape. + std::unique_ptr shape_; - // Deallocate the buffers held by this literal (if the literal owns the - // buffer). - void DeallocateBuffers(); + Piece* root_piece_ = nullptr; // Implementation details shared between Populate() and PopulateParallel() template Status PopulateInternal(const FnType& generator, bool parallel); - Shape shape_; - ShapeTree pieces_; - - // Whether the buffers held in pieces_ are owned by this Literal. - bool owns_buffers_; - - // LiteralView must access and manipulate Pieces of other Literals. - friend class LiteralView; -}; // namespace xla + // Deallocate the buffers held by this literal. + void DeallocateBuffers(); + friend class LiteralBase; +}; std::ostream& operator<<(std::ostream& out, const Literal& literal); -// A read-only view of a Literal. A LiteralView contains pointers to buffers -// owned by the viewed Literal. -// -// TODO(b/71550060): Replace LiteralView with Literal slice classes (immutable -// and mutable) similar to (Mutable)ArraySlice. -class LiteralView : public Literal { +// A read-only view of a Literal. 
A LiteralSlice contains pointers to shape and +// literal buffers always owned by others. +class LiteralSlice : public LiteralBase { public: - // Create and return a view of the given literal rooted at the given shape - // index within the given literal. A factory is used rather than a public - // constructor because only const LiteralViews are supported. It's still - // possible to create non-const LiteralViews via the copy constructors, but - // the factory method makes it a bit less likely. Implementing literal slices - // will fix this undesirable situation (b/71550060). - static const LiteralView Create(const Literal& literal, - const ShapeIndex& view_root = {}); - - LiteralView(const LiteralView& other); - LiteralView& operator=(const LiteralView& other); + LiteralSlice() : LiteralBase() {} - virtual ~LiteralView(); + // Implicit conversion constructors. + LiteralSlice(const LiteralBase& literal); + LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root); private: - LiteralView(const Literal& literal, const ShapeIndex& view_root); + const Piece& root_piece() const override { return *root_piece_; }; + + const Piece* root_piece_; // Not owned. +}; + +// A read-only Literal where the underlying buffers are never owned by this +// class. +class BorrowingLiteral : public LiteralBase { + public: + BorrowingLiteral() : LiteralBase() {} + + // 'src_buf_ptr' is not owned by this class and must outlive the + // lifetime of this class. It points to an appropirately sized buffer with + // data interpretered as indicated by 'shape'. + // This constructor is only used for array shapes. + BorrowingLiteral(const char* src_buf_ptr, const Shape& shape); + // Similar as above, except to be used for constructing non-nested tuples. + BorrowingLiteral(tensorflow::gtl::ArraySlice src_buf_ptrs, + const Shape& shape); + // TODO(b/79707221): adding constructors for nested tuples as well. - // Helper for the copy constructor and copy assignment operator. 
- void CopyFrom(const LiteralView& other); + private: + // Recursively builds the subtree for the given piece and sets the subshapes + // of the given piece with the given shape. + void BuildPieceSubtree(const Shape& shape, Piece* piece); + + // Accessor for the root piece of this literal. + const Piece& root_piece() const override { return root_piece_; }; + Piece root_piece_; + + // Shape of this literal. Stored as unique_ptr so such that the (default) + // move construction of this class would be trivially correct: the pointer to + // Shape root_piece_ stores will still point to the correct address. + std::unique_ptr shape_; }; template -tensorflow::gtl::ArraySlice Literal::Piece::data() const { +tensorflow::gtl::ArraySlice LiteralBase::Piece::data() const { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); CHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType()) @@ -866,7 +1119,7 @@ tensorflow::gtl::ArraySlice Literal::Piece::data() const { } template -tensorflow::gtl::MutableArraySlice Literal::Piece::data() { +tensorflow::gtl::MutableArraySlice LiteralBase::Piece::data() { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); CHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType()) @@ -879,7 +1132,7 @@ tensorflow::gtl::MutableArraySlice Literal::Piece::data() { } template -NativeT Literal::Piece::Get( +NativeT LiteralBase::Piece::Get( tensorflow::gtl::ArraySlice multi_index) const { CHECK(LayoutUtil::IsDenseArray(subshape())); return data()[IndexUtil::MultidimensionalIndexToLinearIndex( @@ -887,15 +1140,15 @@ NativeT Literal::Piece::Get( } template -void Literal::Piece::Set(tensorflow::gtl::ArraySlice multi_index, - NativeT value) { +void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice multi_index, + NativeT value) { CHECK(LayoutUtil::IsDenseArray(subshape())); data()[IndexUtil::MultidimensionalIndexToLinearIndex( subshape(), multi_index)] = value; } template 
-tensorflow::gtl::ArraySlice Literal::data( +tensorflow::gtl::ArraySlice LiteralBase::data( const ShapeIndex& shape_index) const { return piece(shape_index).data(); } @@ -907,13 +1160,13 @@ tensorflow::gtl::MutableArraySlice Literal::data( } template -inline NativeT Literal::Get(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const { +inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { return piece(shape_index).Get(multi_index); } template -inline NativeT Literal::Get( +inline NativeT LiteralBase::Get( tensorflow::gtl::ArraySlice multi_index) const { return root_piece().Get(multi_index); } @@ -1160,13 +1413,13 @@ template } template -NativeT Literal::GetFirstElement() const { +NativeT LiteralBase::GetFirstElement() const { return data().at(0); } template -NativeT Literal::GetSparseElement(int64 sparse_element_number, - const ShapeIndex& shape_index) const { +NativeT LiteralBase::GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index) const { CHECK( LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index))); return data(shape_index)[sparse_element_number]; @@ -1199,7 +1452,7 @@ template } template -void Literal::EachCell( +void LiteralBase::EachCell( std::function indices, NativeT value)> per_cell) const { @@ -1375,7 +1628,7 @@ template } template -std::unique_ptr Literal::Replicate(int64 times) const { +std::unique_ptr LiteralBase::Replicate(int64 times) const { DimensionVector bounds = {times}; bounds.reserve(shape().dimensions_size() + 1); for (int64 bound : shape().dimensions()) { @@ -1410,6 +1663,38 @@ std::unique_ptr Literal::Replicate(int64 times) const { return literal; } +template +/* static */ StatusOr> Literal::CreateRandomLiteral( + const Shape& shape, + const std::function)>& generator) { + using NativeT = typename primitive_util::PrimitiveTypeToNative::type; + TF_RET_CHECK(shape.element_type() == type); + auto literal = 
MakeUnique(shape); + TF_RETURN_IF_ERROR(literal.get()->Populate( + [&](tensorflow::gtl::ArraySlice indexes) { + return generator(indexes); + })); + return std::move(literal); +} + +template +/* static */ StatusOr> Literal::CreateRandomLiteral( + const Shape& shape, E* engine, T mean, T stddev) { + using NativeT = typename primitive_util::PrimitiveTypeToNative::type; + std::normal_distribution generator(mean, stddev); + return CreateRandomLiteral( + shape, [&](tensorflow::gtl::ArraySlice /*indexes*/) { + return generator(*engine); + }); +} + +template +/* static */ StatusOr> Literal::CreateRandomLiteral( + const Shape& shape, T mean, T stddev) { + std::minstd_rand0 engine; + return CreateRandomLiteral(shape, &engine, mean, stddev); +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_ diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 61046784e05623cd3117c24ecc6d6c474739bbd5..53b926163c472c3ed7b72bf8b035d13996d59e34 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/layout_util.h" @@ -974,7 +975,7 @@ TEST_F(LiteralUtilTest, CopyFromTuples) { Literal::CreateR1({2.0, 4.0}).get(), &nil_literal}); - EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0})); + EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); EXPECT_EQ(nested_tuple->Get({}, {1, 0}), 42); EXPECT_EQ(nested_tuple->Get({0}, {1, 1}), 23.0); EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 44.0); @@ -985,7 +986,7 @@ TEST_F(LiteralUtilTest, CopyFromTuples) { /*src_shape_index=*/{})); // The matrix element should be unchanged. 
- EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0})); + EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); // The tuple element should have been copied from 'tuple'. EXPECT_EQ(nested_tuple->Get({}, {1, 0}), -5); @@ -1065,7 +1066,7 @@ TEST_F(LiteralUtilTest, Populate) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = Literal::CreateFromShape(shape); + auto literal = MakeUnique(shape); auto generator = [&](ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. @@ -1107,7 +1108,7 @@ TEST_F(LiteralUtilTest, PopulateParallel) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = Literal::CreateFromShape(shape); + auto literal = MakeUnique(shape); auto generator = [&](ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. 
@@ -1373,36 +1374,36 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { ASSERT_EQ(h1, r[3]); } -TEST_F(LiteralUtilTest, LiteralViewTest) { +TEST_F(LiteralUtilTest, LiteralSliceTest) { auto scalar = Literal::CreateR0(1.0); auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); Literal nil(ShapeUtil::MakeNil()); - EXPECT_EQ(LiteralView::Create(*scalar, {}), *scalar); - EXPECT_EQ(LiteralView::Create(*matrix, {}), *matrix); - EXPECT_EQ(LiteralView::Create(*tuple, {}), *tuple); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {}), *nested_tuple); - EXPECT_EQ(LiteralView::Create(nil, {}), nil); + EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar); + EXPECT_EQ(LiteralSlice(*matrix, {}), *matrix); + EXPECT_EQ(LiteralSlice(*tuple, {}), *tuple); + EXPECT_EQ(LiteralSlice(*nested_tuple, {}), *nested_tuple); + EXPECT_EQ(LiteralSlice(nil, {}), nil); - EXPECT_EQ(LiteralView::Create(*tuple, {0}), *scalar); - EXPECT_EQ(LiteralView::Create(*tuple, {1}), *matrix); + EXPECT_EQ(LiteralSlice(*tuple, {0}), *scalar); + EXPECT_EQ(LiteralSlice(*tuple, {1}), *matrix); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {0}), *tuple); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 0}), *scalar); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 1}), *matrix); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {1}), *scalar); + EXPECT_EQ(LiteralSlice(*nested_tuple, {0}), *tuple); + EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 0}), *scalar); + EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 1}), *matrix); + EXPECT_EQ(LiteralSlice(*nested_tuple, {1}), *scalar); } -TEST_F(LiteralUtilTest, MutatingLiteralView) { +TEST_F(LiteralUtilTest, MutatingLiteralSlice) { auto scalar = Literal::CreateR0(1.0); auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); 
// Verify that changing the underlying data beneath the view changes the // data of the view itself. - const auto nested_tuple_view = LiteralView::Create(*nested_tuple); + const auto nested_tuple_view = LiteralSlice(*nested_tuple); EXPECT_EQ( nested_tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), 1.0f); @@ -1418,19 +1419,57 @@ TEST_F(LiteralUtilTest, MutatingLiteralView) { 555.0f); } -TEST_F(LiteralUtilTest, LiteralViewOfALiteralView) { +TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) { auto scalar = Literal::CreateR0(1.0); auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); - const auto nested_tuple_view = LiteralView::Create(*nested_tuple); - const auto tuple_view = - LiteralView::Create(nested_tuple_view, /*view_root=*/{0}); - const auto matrix_view = LiteralView::Create(tuple_view, /*view_root=*/{1}); + const auto nested_tuple_view = LiteralSlice(*nested_tuple); + const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0}); + const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1}); EXPECT_EQ(matrix_view, *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); } +TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) { + std::vector int64_values = {1, 2, 3}; + const Shape literal_shape = ShapeUtil::MakeShape(S64, {3}); + + BorrowingLiteral literal(reinterpret_cast(int64_values.data()), + literal_shape); + + EXPECT_EQ(literal.Get({0}), 1); + EXPECT_EQ(literal.Get({1}), 2); + EXPECT_EQ(literal.Get({2}), 3); +} + +TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) { + std::vector one_two_three = {1, 2, 3}; + const Shape one_two_three_shape = ShapeUtil::MakeShape(S64, {3}); + + std::vector hundred = {100}; + const Shape hundred_shape = ShapeUtil::MakeShape(S64, {1}); + + std::vector src_buf_ptrs; + src_buf_ptrs.emplace_back( + reinterpret_cast(one_two_three.data())); + 
src_buf_ptrs.emplace_back(reinterpret_cast(hundred.data())); + auto literal_tuple = BorrowingLiteral( + src_buf_ptrs, + ShapeUtil::MakeTupleShape({one_two_three_shape, hundred_shape})); + + EXPECT_EQ(literal_tuple.Get(/*multi_index=*/{0}, /*shape_index=*/{0}), + 1); + EXPECT_EQ(literal_tuple.Get(/*multi_index=*/{0}, /*shape_index=*/{1}), + 100); + + EXPECT_EQ(literal_tuple.Get(/*multi_index=*/{1}, /*shape_index=*/{0}), + 2); + + EXPECT_EQ(literal_tuple.Get(/*multi_index=*/{2}, /*shape_index=*/{0}), + 3); +} + TEST_F(LiteralUtilTest, LiteralMove) { std::unique_ptr matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); @@ -1533,11 +1572,11 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) { EXPECT_EQ(literal.Get({1, 1}), 4.0); } -TEST_F(LiteralUtilTest, LiteralViewCopy) { +TEST_F(LiteralUtilTest, LiteralSliceCopy) { std::unique_ptr matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - const auto matrix_view = LiteralView::Create(*matrix); - LiteralView matrix_view_copy(matrix_view); + const auto matrix_view = LiteralSlice(*matrix); + LiteralSlice matrix_view_copy(matrix_view); EXPECT_EQ(matrix_view_copy.Get({0, 0}), 1.0); EXPECT_EQ(matrix_view_copy.Get({0, 1}), 2.0); @@ -1771,5 +1810,35 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { tensorflow::strings::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); } +TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) { + std::unique_ptr literal = Literal::CreateR1({1, 2}); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr broadcasted_literal, + literal->Broadcast( + /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{0})); + EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{1, 1}, {2, 2}})); +} + +TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) { + std::unique_ptr literal = Literal::CreateR1({1, 2}); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr broadcasted_literal, + literal->Broadcast( + /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{1})); + EXPECT_EQ(*broadcasted_literal, 
*Literal::CreateR2({{1, 2}, {1, 2}})); +} + +TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) { + std::unique_ptr literal = Literal::CreateR0(9); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr broadcasted_literal, + literal->Broadcast( + /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}), + /*dimensions=*/{})); + EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{9, 9}, {9, 9}})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h index 8db8c6f3de84a6c46625eadbb6b0f83d2262e5f7..3c74e070da529b7f1431e01fbaf31932f582db44 100644 --- a/tensorflow/compiler/xla/map_util.h +++ b/tensorflow/compiler/xla/map_util.h @@ -86,11 +86,10 @@ const typename Collection::value_type::second_type& FindOrDefault( // Inserts the key-value pair into the collection. Dies if key was already // present. -template -void InsertOrDie(Collection* const collection, - const typename Collection::value_type::first_type& key, - const typename Collection::value_type::second_type& data) { - auto p = collection->insert(std::make_pair(key, data)); +template +void InsertOrDie(Collection* const collection, Key&& key, Value&& value) { + auto p = collection->insert( + std::make_pair(std::forward(key), std::forward(value))); CHECK(p.second) << "duplicate key: " << key; } @@ -101,9 +100,10 @@ bool ContainsKey(const Collection& collection, const Key& key) { } // Inserts `value` into `set`. Dies if it was already present. 
-template -void InsertOrDie(Set* const set, const typename Set::value_type& value) { - CHECK(set->insert(value).second) << "duplicate value: " << value; +template +void InsertOrDie(Set* const set, Value&& value) { + CHECK(set->insert(std::forward(value)).second) + << "duplicate value: " << value; } } // namespace xla diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 932cce943f7c046a85984e6e5ed6b59dae371473..83834c1ff65ea2f9989fe08279c29056d9070adb 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -12,6 +12,7 @@ py_library( deps = [ ":pywrap_xla", "//tensorflow/compiler/xla:xla_data_proto_py", + "//tensorflow/compiler/xla/service:hlo_proto_py", ], ) @@ -53,6 +54,7 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index df262c97bfcd91a5c2921a36ecb8f8a6172cffe6..ac058feccd3593ce12243d67f61aa8228c76af29 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/platform/default/thread_annotations.h" +#include "tensorflow/core/platform/thread_annotations.h" namespace xla { @@ -276,6 +276,15 @@ const XlaComputation& LocalComputation::computation() const { return computation_; } +string LocalComputation::GetSerializedProto() const { + string result; + if (!computation_.proto().SerializeToString(&result)) { + LOG(ERROR) << "Failed to serialize the HloModuleProto."; + return ""; + } + return result; +} + StatusOr LocalComputation::GetReturnValueShape() const { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, computation_.GetProgramShape()); @@ -589,10 +598,12 @@ _FORWARD_BINOP(Or) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) +_FORWARD_UNOP(Expm1) _FORWARD_UNOP(Floor) _FORWARD_UNOP(Ceil) _FORWARD_UNOP(Round) _FORWARD_UNOP(Log) +_FORWARD_UNOP(Log1p) _FORWARD_UNOP(Sign) _FORWARD_UNOP(Cos) _FORWARD_UNOP(Sin) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index a06b85b4ea28c4f386598901138930eaaed12079..e30c7790b9e627f309203dc371b1fcf90a9b3345 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -112,6 +112,11 @@ class LocalComputation { const XlaComputation& computation() const; + // Returns the HloModuleProto contained in the XlaComputation in the + // serialized binary format. Logs an internal error and returns an empty + // string on failure. + string GetSerializedProto() const; + // Returns the return-value shape for this computation. 
StatusOr GetReturnValueShape() const; @@ -300,10 +305,12 @@ class LocalComputationBuilder { _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) + _FORWARD_UNOP(Expm1) _FORWARD_UNOP(Floor) _FORWARD_UNOP(Ceil) _FORWARD_UNOP(Round) _FORWARD_UNOP(Log) + _FORWARD_UNOP(Log1p) _FORWARD_UNOP(Sign) _FORWARD_UNOP(Cos) _FORWARD_UNOP(Sin) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 04c56bbba95fbf3248df6c49700ff563c8b253c0..fcd30b6c2f851dea2b206497bbb5d4cbfceb99e7 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -851,6 +851,11 @@ tensorflow::ImportNumpy(); })) { return nullptr; } + if (!HandleStringAttribute($input, "dump_unoptimized_hlo_proto_to", [&](string s) { + build_options.set_dump_unoptimized_hlo_proto_to(std::move(s)); + })) { + return nullptr; + } if (!HandleStringAttribute($input, "dump_per_pass_hlo_proto_to", [&](string s) { build_options.set_dump_per_pass_hlo_proto_to(std::move(s)); })) { @@ -906,6 +911,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputation; %unignore xla::swig::LocalComputation::Compile; %unignore xla::swig::LocalComputation::GetReturnValueShape; +%unignore xla::swig::LocalComputation::GetSerializedProto; %unignore xla::swig::LocalOp; %unignore xla::swig::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; @@ -968,10 +974,12 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Not; %unignore xla::swig::LocalComputationBuilder::Abs; %unignore xla::swig::LocalComputationBuilder::Exp; +%unignore xla::swig::LocalComputationBuilder::Expm1; %unignore xla::swig::LocalComputationBuilder::Floor; %unignore xla::swig::LocalComputationBuilder::Ceil; %unignore xla::swig::LocalComputationBuilder::Round; %unignore xla::swig::LocalComputationBuilder::Log; +%unignore 
xla::swig::LocalComputationBuilder::Log1p; %unignore xla::swig::LocalComputationBuilder::Sign; %unignore xla::swig::LocalComputationBuilder::Cos; %unignore xla::swig::LocalComputationBuilder::Sin; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index dc6f5fe5fcc067c99ced01812f9f2388a00766d0..68648a3a176363de69a56ecb8070f82862874e94 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -340,13 +340,13 @@ StatusOr OpMetadataFromPyObject(PyObject* o) { return result; } -PyObject* PyObjectFromXlaLiteral(const Literal& literal) { +PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { if (ShapeUtil::IsTuple(literal.shape())) { int num_elements = ShapeUtil::TupleElementCount(literal.shape()); PyObject* tuple = PyTuple_New(num_elements); for (int i = 0; i < num_elements; i++) { - PyTuple_SET_ITEM( - tuple, i, PyObjectFromXlaLiteral(LiteralView::Create(literal, {i}))); + PyTuple_SET_ITEM(tuple, i, + PyObjectFromXlaLiteral(LiteralSlice(literal, {i}))); } return tuple; } else { @@ -431,7 +431,7 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, return Status::OK(); } -void CopyLiteralToNumpyArray(int np_type, const Literal& literal, +void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, PyArrayObject* py_array) { switch (np_type) { case NPY_BOOL: diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index 9656cb1c31c39dbe54293700c2765d0723255657..64f0aae0f9790f0199ac6cb931a5c9f6dc356f4c 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -74,7 +74,7 @@ StatusOr OpMetadataFromPyObject(PyObject* o); // array data. // // The return value is a new reference. 
-PyObject* PyObjectFromXlaLiteral(const Literal& literal); +PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal); // Converts a Numpy ndarray or a nested Python tuple thereof to a // corresponding XLA literal. @@ -90,7 +90,7 @@ StatusOr > XlaLiteralFromPyObject(PyObject* o); Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, Literal* literal); -void CopyLiteralToNumpyArray(int np_type, const Literal& literal, +void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, PyArrayObject* py_array); template @@ -101,7 +101,8 @@ void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) { } template -void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) { +void CopyLiteralToNumpyArray(const LiteralSlice& literal, + PyArrayObject* py_array) { NativeT* dest = static_cast(PyArray_DATA(py_array)); auto source = literal.data(); std::copy(source.begin(), source.end(), dest); diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 1d5b75d1bee2dcee3e448d0bcb72103b539efac6..8b03682892bff4948d273491e3176b8dda8d5e77 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -28,6 +28,7 @@ import numpy as np from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python import pywrap_xla as c_api +from tensorflow.compiler.xla.service import hlo_pb2 # Most functions are snake_case for consistency with other modules, whereas @@ -88,10 +89,12 @@ _UNARY_OPS = [ 'Not', 'Abs', 'Exp', + 'Expm1', 'Floor', 'Round', 'Ceil', 'Log', + 'Log1p', 'Sign', 'Cos', 'Sin', @@ -352,6 +355,7 @@ class CompileOptions(object): def __init__(self): self.generate_hlo_graph = None self.dump_optimized_hlo_proto_to = None + self.dump_unoptimized_hlo_proto_to = None self.dump_per_pass_hlo_proto_to = None self.hlo_profile = False @@ -410,6 +414,17 @@ class LocalComputation(object): assert 
isinstance(c_local_computation, c_api.LocalComputation) self._delete = c_api.DeleteLocalComputation + def GetProto(self): + """Get the HloModuleProto proto object in this local computation. + + Returns: + An HloModuleProto proto object that has the whole-graph information. + """ + + serialized = self.c_local_computation.GetSerializedProto() + proto = hlo_pb2.HloModuleProto.FromString(serialized) + return proto + def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None): """Compiles an un-compiled local computation. @@ -1100,6 +1115,61 @@ class ComputationBuilder(object): dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd)) return dimension_numbers + def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation, + rhs_dilation, dimension_numbers): + """Enqueues a ConvGeneralDilated operation onto the computation. + + Args: + lhs: LocalOp for the rank N+2 array of inputs. + rhs: LocalOp for the rank N+2 array of kernel weights. + window_strides: length-N array-like of integer kernel strides. + padding: length-N array-like of pairs of integers of (low, high) padding. + lhs_dilation: length-N array-like of integer dilation factors. + rhs_dilation: length-N array-like of integer dilation factors. + dimension_numbers: either an xla_data_pb2.ConvolutionDimensionNumbers or a + triple (lhs_spec, rhs_spec, out_spec) where each element is a string of + length N+2 identifying by position (1) batch dimensions in lhs, rhs, and + the output with the character 'N', (2) feature dimensions in lhs and the + output with the character 'C', (3) input and output feature dimensions + in rhs with the characters 'I' and 'O' respectively, and (4) spatial + dimension correspondences between lhs, rhs, and the output using any + distinct characters. For example, to indicate dimension numbers + consistent with the Conv operation with two spatial dimensions, one + could use ('NCHW', 'OIHW', 'NCHW'). 
As another example, to indicate + dimension numbers consistent with the TensorFlow Conv2D operation, one + could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of + convolution dimension specification, window strides are associated with + spatial dimension character labels according to the order in which the + labels appear in the rhs_spec string, so that window_strides[0] is + matched with the dimension corresponding to the first character + appearing in rhs_spec that is not 'I' or 'O'. + + Returns: a LocalOp representing the ConvGenralDilated operation. + """ + if not isinstance(dimension_numbers, + xla_data_pb2.ConvolutionDimensionNumbers): + lhs_spec, rhs_spec, out_spec = dimension_numbers + dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers() + + dimension_numbers.input_batch_dimension = lhs_spec.index('N') + dimension_numbers.input_feature_dimension = lhs_spec.index('C') + dimension_numbers.output_batch_dimension = out_spec.index('N') + dimension_numbers.output_feature_dimension = out_spec.index('C') + dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O') + dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I') + + dimension_numbers.kernel_spatial_dimensions.extend( + i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'}) + dimension_numbers.input_spatial_dimensions.extend( + sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(lhs_spec[i]))) + dimension_numbers.output_spatial_dimensions.extend( + sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(out_spec[i]))) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, + dimension_numbers) + def _forward_methods_to_local_builder(): """Forward remaining ComputationBuilder methods to the C API. 
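Reviewer note on the `ConvGeneralDilated` spec strings added above: the mapping from a `(lhs_spec, rhs_spec, out_spec)` triple to dimension indices can be sketched in isolation. This is a minimal, standalone approximation that mirrors the logic in the patch; the helper name `parse_conv_dimension_numbers` and the plain-dict return value are assumptions made for illustration, and the actual patch fills an `xla_data_pb2.ConvolutionDimensionNumbers` proto instead.

```python
def parse_conv_dimension_numbers(lhs_spec, rhs_spec, out_spec):
    """Hypothetical helper: turn spec strings into dimension indices.

    Mirrors the index logic of the patch, but returns a plain dict rather
    than a ConvolutionDimensionNumbers proto (an assumption of this sketch).
    """
    # Spatial dimensions of the kernel keep their order of appearance in rhs_spec.
    kernel_spatial = [i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'}]
    # Input/output spatial dimensions are reordered to follow the kernel's order.
    input_spatial = sorted(
        (i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}),
        key=lambda i: rhs_spec.index(lhs_spec[i]))
    output_spatial = sorted(
        (i for i, c in enumerate(out_spec) if c not in {'N', 'C'}),
        key=lambda i: rhs_spec.index(out_spec[i]))
    return {
        'input_batch_dimension': lhs_spec.index('N'),
        'input_feature_dimension': lhs_spec.index('C'),
        'output_batch_dimension': out_spec.index('N'),
        'output_feature_dimension': out_spec.index('C'),
        'kernel_output_feature_dimension': rhs_spec.index('O'),
        'kernel_input_feature_dimension': rhs_spec.index('I'),
        'kernel_spatial_dimensions': kernel_spatial,
        'input_spatial_dimensions': input_spatial,
        'output_spatial_dimensions': output_spatial,
    }


# The permuted layout used by testConvGeneralDilatedPermutedF32 in this patch.
permuted = parse_conv_dimension_numbers('NHWC', 'OIHW', 'CWNH')
# The default layout, consistent with the plain Conv operation.
default = parse_conv_dimension_numbers('NCHW', 'OIHW', 'NCHW')
```

Because the spatial labels are sorted by their position in `rhs_spec`, `window_strides[0]` always pairs with the first kernel spatial label, whatever order the labels take in `lhs_spec` or `out_spec`.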
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index c073c02040e4d260cf760ea2b25f70d60ddd41a1..6c0680f44374e6d182190fa3b9c155fcf4148c8b 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -164,6 +164,16 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + def testGetProto(self): + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])), + c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) + built = c.Build() + proto = built.GetProto() # HloModuleProto + self.assertTrue(len(proto.computations) == 1) + self.assertTrue(len(proto.computations[0].instructions) == 3) + def testSum2DF64(self): c = self._NewComputation() c.Add( @@ -509,6 +519,46 @@ class SingleOpTest(LocalComputationTest): [40., 50., 0.]]]]) self._ExecuteAndCompareClose(c, expected=result) + def testConvGeneralDilatedF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = ("NCHW", "OIHW", "NCHW") + c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs), + strides, pads, lhs_dilation, rhs_dilation, + dimension_numbers) + result = np.array([[[[0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.]]]]) + self._ExecuteAndCompareClose(c, expected=result) + + def testConvGeneralDilatedPermutedF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + + dimension_numbers = ("NHWC", "OIHW", 
"CWNH") + c.ConvGeneralDilated(c.Constant(np.transpose(lhs, (0, 2, 3, 1))), + c.Constant(rhs), + strides, pads, lhs_dilation, rhs_dilation, + dimension_numbers) + result = np.array([[[[0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.]]]]) + self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2))) + def testBooleanNot(self): c = self._NewComputation() arr = NumpyArrayBool([True, False, True]) @@ -521,6 +571,12 @@ class SingleOpTest(LocalComputationTest): c.Exp(c.Constant(arr)) self._ExecuteAndCompareClose(c, expected=np.exp(arr)) + def testExpm1(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Expm1(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.expm1(arr)) + def testRound(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) @@ -533,6 +589,12 @@ class SingleOpTest(LocalComputationTest): c.Log(c.Constant(arr)) self._ExecuteAndCompareClose(c, expected=np.log(arr)) + def testLog1p(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Log1p(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.log1p(arr)) + def testNeg(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 28d6a8c3fe85fa4179bf2f41c82ad4eb93a045fe..8fa6961d197dce519cf151283b8bc0836a4615c0 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -265,9 +265,9 @@ class ReferenceUtil { const Array3D& rhs, int concatenate_dimension) { CHECK(0 <= concatenate_dimension && concatenate_dimension < 3); - std::vector lhs_dims = {lhs.n1(), lhs.n2(), lhs.n3()}; - std::vector rhs_dims = {rhs.n1(), rhs.n2(), rhs.n3()}; - std::vector out_dims = {rhs.n1(), rhs.n2(), rhs.n3()}; + const int64 lhs_dims[] = {lhs.n1(), lhs.n2(), lhs.n3()}; + const int64 rhs_dims[] = {rhs.n1(), rhs.n2(), rhs.n3()}; + int64 out_dims[] = {rhs.n1(), rhs.n2(), 
rhs.n3()}; for (int i = 0; i < 3; ++i) { if (i != concatenate_dimension) { out_dims[i] = lhs_dims[i]; @@ -299,9 +299,9 @@ class ReferenceUtil { const Array4D& rhs, int concatenate_dimension) { CHECK(0 <= concatenate_dimension && concatenate_dimension < 4); - std::vector lhs_dims = {lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}; - std::vector rhs_dims = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; - std::vector out_dims = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; + const int64 lhs_dims[] = {lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}; + const int64 rhs_dims[] = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; + int64 out_dims[] = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; for (int i = 0; i < 4; ++i) { if (i != concatenate_dimension) { out_dims[i] = lhs_dims[i]; @@ -330,13 +330,14 @@ class ReferenceUtil { return result; } - // Slices with modulo-wrapping. + // Slices with index clamping template - static std::vector ModSlice1D(const tensorflow::gtl::ArraySlice& input, - int64 start, int64 size) { + static std::vector ClampSlice1D( + const tensorflow::gtl::ArraySlice& input, int64 start, int64 size) { + start = std::min(std::max(0, start), input.size() - size); std::vector result; for (int64 i = 0; i < size; ++i) { - result.push_back(input[(start + i) % input.size()]); + result.push_back(input[(start + i)]); } return result; } @@ -552,12 +553,11 @@ class ReferenceUtil { const NativeT pad) { CHECK_EQ(padding.dimensions_size(), 3); - const std::vector input_bounds = {operand.n1(), operand.n2(), - operand.n3()}; - std::vector pad_low(3); - std::vector pad_high(3); - std::vector pad_interior(3); - std::vector output_bounds(3); + const int64 input_bounds[] = {operand.n1(), operand.n2(), operand.n3()}; + int64 pad_low[3]; + int64 pad_high[3]; + int64 pad_interior[3]; + int64 output_bounds[3]; for (int64 i = 0; i < 3; ++i) { pad_low[i] = padding.dimensions(i).edge_padding_low(); pad_high[i] = padding.dimensions(i).edge_padding_high(); @@ -573,7 +573,7 @@ class ReferenceUtil { Array3D 
result(output_bounds[0], output_bounds[1], output_bounds[2]); - std::vector indices = {0, 0, 0}; + int indices[] = {0, 0, 0}; for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) { for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) { for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) { @@ -611,12 +611,12 @@ class ReferenceUtil { const NativeT pad) { CHECK_EQ(padding.dimensions_size(), 4); - const std::vector input_bounds = {operand.n1(), operand.n2(), - operand.n3(), operand.n4()}; - std::vector pad_low(4); - std::vector pad_high(4); - std::vector pad_interior(4); - std::vector output_bounds(4); + const int64 input_bounds[] = {operand.n1(), operand.n2(), operand.n3(), + operand.n4()}; + int64 pad_low[4]; + int64 pad_high[4]; + int64 pad_interior[4]; + int64 output_bounds[4]; for (int64 i = 0; i < 4; ++i) { pad_low[i] = padding.dimensions(i).edge_padding_low(); pad_high[i] = padding.dimensions(i).edge_padding_high(); diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 0d56a9a477b15964ad45e798865aa8d2c7385073..1775666652f303ee095a11405537106a3eb9b056 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -42,7 +42,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", - "@grpc//:grpc++_unsecure", + "@grpc//:grpc++", ], ) @@ -61,7 +61,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", - "@grpc//:grpc++_unsecure", + "@grpc//:grpc++", ], ) @@ -74,6 +74,6 @@ cc_library( "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core/distributed_runtime/rpc:grpc_util", - "@grpc//:grpc++_unsecure", + "@grpc//:grpc++", ], ) diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index 
10997c0719dfb80efc7b855c7888500caeb1591b..d7dd9786a2bbde2d18ae81a9a9d4cc4b2cc38411 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include -#include "grpc++/create_channel.h" -#include "grpc++/security/credentials.h" +#include "grpcpp/create_channel.h" +#include "grpcpp/security/credentials.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" @@ -101,8 +101,8 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) { TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer( computation, {}, nullptr)); - LiteralTestUtil::ExpectNear(*expected_literal, *result_literal, - ErrorSpec(0.0001)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal, + ErrorSpec(0.0001))); } } // namespace diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc index ffb72fc73c5bc1ad6e648fb3d772eb5749700dc0..4e1435fa30a24c320ddbedb84d37b369a3158a54 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service.cc @@ -27,24 +27,11 @@ namespace xla { return std::move(grpc_service); } -::grpc::Status DelegateRPC(std::function op) { - tensorflow::Status s = op(); +::grpc::Status DelegateRPC(std::function op) { + Status s = op(); return tensorflow::ToGrpcStatus(s); } -::grpc::Status GRPCService::Computation(::grpc::ServerContext* context, - const ComputationRequest* arg, - ComputationResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->Computation(arg, result); }); -} - -::grpc::Status GRPCService::CreateOp(::grpc::ServerContext* context, - const OpRequest* arg, OpResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->Op(arg, result); }); -} - ::grpc::Status 
GRPCService::Unregister(::grpc::ServerContext* context, const UnregisterRequest* arg, UnregisterResponse* result) { @@ -60,21 +47,6 @@ namespace xla { }); } -::grpc::Status GRPCService::SetReturnValue(::grpc::ServerContext* context, - const SetReturnValueRequest* arg, - SetReturnValueResponse* results) { - return DelegateRPC([this, arg, results]() { - return service_->SetReturnValue(arg, results); - }); -} - -::grpc::Status GRPCService::Execute(::grpc::ServerContext* context, - const ExecuteRequest* arg, - ExecuteResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->Execute(arg, result); }); -} - ::grpc::Status GRPCService::ExecuteGraph(::grpc::ServerContext* /*context*/, const ExecuteGraphRequest* arg, ExecuteResponse* result) { @@ -82,13 +54,6 @@ namespace xla { [this, arg, result]() { return service_->ExecuteGraph(arg, result); }); } -::grpc::Status GRPCService::ExecuteAsync(::grpc::ServerContext* context, - const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->ExecuteAsync(arg, result); }); -} - ::grpc::Status GRPCService::WaitForExecution(::grpc::ServerContext* context, const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) { @@ -136,20 +101,6 @@ namespace xla { [this, arg, result]() { return service_->ResetDevice(arg, result); }); } -::grpc::Status GRPCService::IsConstant(::grpc::ServerContext* context, - const IsConstantRequest* arg, - IsConstantResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->IsConstant(arg, result); }); -} - -::grpc::Status GRPCService::ComputeConstant(::grpc::ServerContext* context, - const ComputeConstantRequest* arg, - ComputeConstantResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->ComputeConstant(arg, result); }); -} - ::grpc::Status GRPCService::GetShape(::grpc::ServerContext* context, const GetShapeRequest* arg, GetShapeResponse* result) { @@ 
-157,43 +108,4 @@ namespace xla { [this, arg, result]() { return service_->GetShape(arg, result); }); } -::grpc::Status GRPCService::GetComputationShape( - ::grpc::ServerContext* context, const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->GetComputationShape(arg, result); - }); -} - -::grpc::Status GRPCService::GetLocalShape(::grpc::ServerContext* context, - const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->GetLocalShape(arg, result); }); -} - -::grpc::Status GRPCService::GetComputationStats( - ::grpc::ServerContext* context, const ComputationStatsRequest* arg, - ComputationStatsResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->GetComputationStats(arg, result); - }); -} - -::grpc::Status GRPCService::SnapshotComputation( - ::grpc::ServerContext* context, const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->SnapshotComputation(arg, result); - }); -} - -::grpc::Status GRPCService::LoadComputationSnapshot( - ::grpc::ServerContext* context, const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->LoadComputationSnapshot(arg, result); - }); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h index 50f02796f2d45baf894841782cd96d8d51a5ba00..ca1b09b648013ad45d806040c5ddcf11d9e5604e 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.h +++ b/tensorflow/compiler/xla/rpc/grpc_service.h @@ -16,7 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_ #define TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_ -#include "grpc++/server_context.h" +#include "grpcpp/server_context.h" #include "tensorflow/compiler/xla/rpc/xla_service.grpc.pb.h" #include "tensorflow/compiler/xla/service/service.h" @@ -31,13 +31,6 @@ class GRPCService : public grpc::XlaService::Service { static StatusOr> NewService( se::Platform* platform = nullptr); - ::grpc::Status Computation(::grpc::ServerContext* context, - const ComputationRequest* arg, - ComputationResponse* result) override; - - ::grpc::Status CreateOp(::grpc::ServerContext* context, const OpRequest* arg, - OpResponse* result) override; - ::grpc::Status Unregister(::grpc::ServerContext* context, const UnregisterRequest* arg, UnregisterResponse* result) override; @@ -46,22 +39,10 @@ class GRPCService : public grpc::XlaService::Service { const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - ::grpc::Status SetReturnValue(::grpc::ServerContext* context, - const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; - - ::grpc::Status Execute(::grpc::ServerContext* context, - const ExecuteRequest* arg, - ExecuteResponse* result) override; - ::grpc::Status ExecuteGraph(::grpc::ServerContext* context, const ExecuteGraphRequest* arg, ExecuteResponse* result) override; - ::grpc::Status ExecuteAsync(::grpc::ServerContext* context, - const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; - ::grpc::Status WaitForExecution(::grpc::ServerContext* context, const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) override; @@ -86,38 +67,10 @@ class GRPCService : public grpc::XlaService::Service { const ResetDeviceRequest* arg, ResetDeviceResponse* result) override; - ::grpc::Status IsConstant(::grpc::ServerContext* context, - const IsConstantRequest* arg, - IsConstantResponse* result) override; - - ::grpc::Status ComputeConstant(::grpc::ServerContext* context, - const 
ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; - ::grpc::Status GetShape(::grpc::ServerContext* context, const GetShapeRequest* arg, GetShapeResponse* result) override; - ::grpc::Status GetComputationShape( - ::grpc::ServerContext* context, const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; - - ::grpc::Status GetLocalShape(::grpc::ServerContext* context, - const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; - - ::grpc::Status GetComputationStats(::grpc::ServerContext* context, - const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; - - ::grpc::Status SnapshotComputation( - ::grpc::ServerContext* context, const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) override; - - ::grpc::Status LoadComputationSnapshot( - ::grpc::ServerContext* context, const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) override; - private: std::unique_ptr<::xla::Service> service_; diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc index e29908ccec80db76e3b5b856e57382c56430c379..c68c857c304138ff4318e243f66547c6acce1005 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service_main.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc @@ -15,9 +15,9 @@ limitations under the License. // Basic server binary that exposes a xla::Service through a GRPC interface // on a configurable port. 
-#include "grpc++/security/server_credentials.h" -#include "grpc++/server.h" -#include "grpc++/server_builder.h" +#include "grpcpp/security/server_credentials.h" +#include "grpcpp/server.h" +#include "grpcpp/server_builder.h" #include "tensorflow/compiler/xla/rpc/grpc_service.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/init_main.h" diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.cc b/tensorflow/compiler/xla/rpc/grpc_stub.cc index e1f2b0abe39b10dd82b700941748bc4f4e8cb2f8..7b8ab158e1396d7087a407be180ab44d2e16e121 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.cc +++ b/tensorflow/compiler/xla/rpc/grpc_stub.cc @@ -20,82 +20,56 @@ namespace xla { GRPCStub::~GRPCStub() = default; -tensorflow::Status MakeRPC( +Status MakeRPC( const std::function<::grpc::Status(::grpc::ClientContext*)>& rpc_method) { ::grpc::ClientContext context; ::grpc::Status s = rpc_method(&context); return tensorflow::FromGrpcStatus(s); } -tensorflow::Status GRPCStub::TransferToClient( - const TransferToClientRequest* request, - TransferToClientResponse* response) { +Status GRPCStub::TransferToClient(const TransferToClientRequest* request, + TransferToClientResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->TransferToClient(context, *request, response); }); } -tensorflow::Status GRPCStub::TransferToServer( - const TransferToServerRequest* request, - TransferToServerResponse* response) { +Status GRPCStub::TransferToServer(const TransferToServerRequest* request, + TransferToServerResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->TransferToServer(context, *request, response); }); } -tensorflow::Status GRPCStub::TransferToInfeed( - const TransferToInfeedRequest* request, - TransferToInfeedResponse* response) { +Status GRPCStub::TransferToInfeed(const TransferToInfeedRequest* request, + TransferToInfeedResponse* response) { 
return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->TransferToInfeed(context, *request, response); }); } -tensorflow::Status GRPCStub::TransferFromOutfeed( - const TransferFromOutfeedRequest* request, - TransferFromOutfeedResponse* response) { +Status GRPCStub::TransferFromOutfeed(const TransferFromOutfeedRequest* request, + TransferFromOutfeedResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->TransferFromOutfeed(context, *request, response); }); } -tensorflow::Status GRPCStub::ResetDevice(const ResetDeviceRequest* request, - ResetDeviceResponse* response) { +Status GRPCStub::ResetDevice(const ResetDeviceRequest* request, + ResetDeviceResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->ResetDevice(context, *request, response); }); } -tensorflow::Status GRPCStub::LoadComputationSnapshot( - const LoadComputationSnapshotRequest* request, - LoadComputationSnapshotResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->LoadComputationSnapshot(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::Execute(const ExecuteRequest* request, - ExecuteResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->Execute(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, - ExecuteResponse* response) { +Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, + ExecuteResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->ExecuteGraph(context, *request, response); }); } -tensorflow::Status GRPCStub::ExecuteParallel( - const ExecuteParallelRequest* request, ExecuteParallelResponse* response) { - return MakeRPC([this, request, 
response](::grpc::ClientContext* context) { - return grpc_stub_->ExecuteParallel(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::ExecuteGraphParallel( +Status GRPCStub::ExecuteGraphParallel( const ExecuteGraphParallelRequest* request, ExecuteParallelResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -103,38 +77,21 @@ tensorflow::Status GRPCStub::ExecuteGraphParallel( }); } -tensorflow::Status GRPCStub::ExecuteAsync(const ExecuteAsyncRequest* request, - ExecuteAsyncResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ExecuteAsync(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::WaitForExecution( - const WaitForExecutionRequest* request, - WaitForExecutionResponse* response) { +Status GRPCStub::WaitForExecution(const WaitForExecutionRequest* request, + WaitForExecutionResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->WaitForExecution(context, *request, response); }); } -tensorflow::Status GRPCStub::DeconstructTuple( - const DeconstructTupleRequest* request, - DeconstructTupleResponse* response) { +Status GRPCStub::DeconstructTuple(const DeconstructTupleRequest* request, + DeconstructTupleResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->DeconstructTuple(context, *request, response); }); } -tensorflow::Status GRPCStub::GetComputationStats( - const ComputationStatsRequest* request, - ComputationStatsResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->GetComputationStats(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::GetComputationGraphStats( +Status GRPCStub::GetComputationGraphStats( const ComputationGraphStatsRequest* request, ComputationStatsResponse* response) { return MakeRPC([this, 
request, response](::grpc::ClientContext* context) { @@ -142,81 +99,28 @@ tensorflow::Status GRPCStub::GetComputationGraphStats( }); } -tensorflow::Status GRPCStub::GetComputationShape( - const GetComputationShapeRequest* request, - GetComputationShapeResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->GetComputationShape(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::GetShape(const GetShapeRequest* request, - GetShapeResponse* response) { +Status GRPCStub::GetShape(const GetShapeRequest* request, + GetShapeResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->GetShape(context, *request, response); }); } -tensorflow::Status GRPCStub::GetDeviceHandles( - const GetDeviceHandlesRequest* request, - GetDeviceHandlesResponse* response) { +Status GRPCStub::GetDeviceHandles(const GetDeviceHandlesRequest* request, + GetDeviceHandlesResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->GetDeviceHandles(context, *request, response); }); } -tensorflow::Status GRPCStub::CreateChannelHandle( - const CreateChannelHandleRequest* request, - CreateChannelHandleResponse* response) { +Status GRPCStub::CreateChannelHandle(const CreateChannelHandleRequest* request, + CreateChannelHandleResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->CreateChannelHandle(context, *request, response); }); } -// Methods used by ComputationBuilder. 
-tensorflow::Status GRPCStub::Computation(const ComputationRequest* request, - ComputationResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->Computation(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::Op(const OpRequest* request, - OpResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->CreateOp(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::GetLocalShape(const GetLocalShapeRequest* request, - GetLocalShapeResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->GetLocalShape(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::SetReturnValue( - const SetReturnValueRequest* request, SetReturnValueResponse* responses) { - return MakeRPC([this, request, responses](::grpc::ClientContext* context) { - return grpc_stub_->SetReturnValue(context, *request, responses); - }); -} - -tensorflow::Status GRPCStub::IsConstant(const IsConstantRequest* request, - IsConstantResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->IsConstant(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::ComputeConstant( - const ComputeConstantRequest* request, ComputeConstantResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ComputeConstant(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::ComputeConstantGraph( +Status GRPCStub::ComputeConstantGraph( const ComputeConstantGraphRequest* request, ComputeConstantResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -224,18 +128,9 @@ tensorflow::Status GRPCStub::ComputeConstantGraph( }); } -// Methods used by Computation. 
-tensorflow::Status GRPCStub::SnapshotComputation( - const SnapshotComputationRequest* request, - SnapshotComputationResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->SnapshotComputation(context, *request, response); - }); -} - // Methods used by GlobalData. -tensorflow::Status GRPCStub::Unregister(const UnregisterRequest* request, - UnregisterResponse* response) { +Status GRPCStub::Unregister(const UnregisterRequest* request, + UnregisterResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->Unregister(context, *request, response); }); diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.h b/tensorflow/compiler/xla/rpc/grpc_stub.h index fd9810d4f1a5e084b73e83007ea7f9f8b0462c72..8dfcb761387d608abbb1f62974f49b976a7ff7ff 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.h +++ b/tensorflow/compiler/xla/rpc/grpc_stub.h @@ -28,105 +28,51 @@ class GRPCStub : public ServiceInterface { explicit GRPCStub(grpc::XlaService::Stub* stub) : grpc_stub_(stub) {} ~GRPCStub() override; - tensorflow::Status TransferToClient( - const TransferToClientRequest* arg, - TransferToClientResponse* result) override; + Status TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) override; - tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, - TransferToServerResponse* result) override; + Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) override; - tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) override; + Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override; - tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) override; + Status TransferFromOutfeed(const 
TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override; - tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) override; + Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override; - tensorflow::Status LoadComputationSnapshot( - const LoadComputationSnapshotRequest* request, - LoadComputationSnapshotResponse* result) override; + Status ExecuteGraph(const ExecuteGraphRequest* request, + ExecuteResponse* response) override; - tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) override; + Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* request, + ExecuteParallelResponse* response) override; - tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* request, - ExecuteResponse* response) override; + Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override; - tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override; + Status DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) override; - tensorflow::Status ExecuteGraphParallel( - const ExecuteGraphParallelRequest* request, - ExecuteParallelResponse* response) override; + Status GetComputationGraphStats(const ComputationGraphStatsRequest* request, + ComputationStatsResponse* response) override; - tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; + Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) override; - tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) override; + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override; - tensorflow::Status DeconstructTuple( - const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result) override; + Status 
CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) override; - tensorflow::Status GetComputationStats( - const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; - - tensorflow::Status GetComputationGraphStats( - const ComputationGraphStatsRequest* request, - ComputationStatsResponse* response) override; - - tensorflow::Status GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; - - tensorflow::Status GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) override; - - tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) override; - - tensorflow::Status CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) override; - - // Methods used by ComputationBuilder. - tensorflow::Status Computation(const ComputationRequest* arg, - ComputationResponse* result) override; - - tensorflow::Status Op(const OpRequest* arg, OpResponse* result) override; - tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; - - tensorflow::Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; - - tensorflow::Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) override; - - tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; - - tensorflow::Status ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, - ComputeConstantResponse* result) override; - - // Methods used by Computation. - tensorflow::Status SnapshotComputation( - const SnapshotComputationRequest* ag, - SnapshotComputationResponse* result) override; + Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) override; // Methods used by GlobalData. 
- tensorflow::Status Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) override; + Status Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) override; grpc::XlaService::Stub* service() { return grpc_stub_; } diff --git a/tensorflow/compiler/xla/rpc/xla_service.proto b/tensorflow/compiler/xla/rpc/xla_service.proto index c47164ee1b7657ae378a053f553442bee751753e..551ae895e05586daec0ffcd425f4950f76bdd50d 100644 --- a/tensorflow/compiler/xla/rpc/xla_service.proto +++ b/tensorflow/compiler/xla/rpc/xla_service.proto @@ -75,19 +75,7 @@ service XlaService { rpc GetShape(GetShapeRequest) returns (GetShapeResponse) { } - // Requests the program shape of the referenced computation. - rpc GetComputationShape(GetComputationShapeRequest) - returns (GetComputationShapeResponse) { - } - - // Requests the statistics of the given computation. - rpc GetComputationStats(ComputationStatsRequest) - returns (ComputationStatsResponse) { - } - // Requests the statistics of the given computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. rpc GetComputationGraphStats(ComputationGraphStatsRequest) returns (ComputationStatsResponse) { } @@ -121,25 +109,12 @@ service XlaService { rpc ResetDevice(ResetDeviceRequest) returns (ResetDeviceResponse) { } - // Tests if an expression is a compile-time constant. - rpc IsConstant(IsConstantRequest) returns (IsConstantResponse) { - } - - // Computes the value of a constant expression. - rpc ComputeConstant(ComputeConstantRequest) - returns (ComputeConstantResponse) { - } - // Computes the value of a constant expression. The request contains the // computation graph for the constant expression. rpc ComputeConstantGraph(ComputeConstantGraphRequest) returns (ComputeConstantResponse) { } - // Retrieves the inferred shape for a value within a computation. - rpc GetLocalShape(GetLocalShapeRequest) returns (GetLocalShapeResponse) { - } - // Requests one or more device handles from the target. 
The returned device // handles can be used to specify the device on which to execute computations // or transfer data. @@ -153,32 +128,6 @@ service XlaService { returns (CreateChannelHandleResponse) { } - // Requests that the referenced computation be specialized for the provided - // arguments for subsequent execution. This permits things such as value - // specialization. - rpc Specialize(SpecializeRequest) returns (SpecializeResponse) { - } - - // Modifies the provided computation so that subsequent executions - // will compute the provided ComputationDataHandle, rather than the - // last expression enqueued on that Computation. - rpc SetReturnValue(SetReturnValueRequest) returns (SetReturnValueResponse) { - } - - // Computation creates a new computation with the given name. - // A unique ComputationHandle is returned. - rpc Computation(ComputationRequest) returns (ComputationResponse) { - } - - // Adds a new op to a computation. - rpc CreateOp(OpRequest) returns (OpResponse) { - } - - // Invokes the provided computation with the provided global data passed as - // immutable arguments. Returns global data output and execution timing. - rpc Execute(ExecuteRequest) returns (ExecuteResponse) { - } - // Invokes the provided computation with the provided global data passed as // immutable arguments. The request contains the whole computation graph. // Returns global data output and execution timing. @@ -188,38 +137,13 @@ service XlaService { // Invokes the provided list of computations in parallel with the provided // global data for each computation. Returns a list of global data output and // execution timing. - rpc ExecuteParallel(ExecuteParallelRequest) - returns (ExecuteParallelResponse) { - } - - // Invokes the provided list of computations in parallel with the provided - // global data for each computation. Returns a list of global data output and - // execution timing. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. 
rpc ExecuteGraphParallel(ExecuteGraphParallelRequest) returns (ExecuteParallelResponse) { } - // Invokes the provided computation with the provided global data passed as - // immutable arguments. Returns a handle to the execution. - rpc ExecuteAsync(ExecuteAsyncRequest) returns (ExecuteAsyncResponse) { - } - // Waits until the given execution (aysnchronously launched) is complete, and // returns the global data output. rpc WaitForExecution(WaitForExecutionRequest) returns (WaitForExecutionResponse) { } - - // Serializes a computation to proto form, so it can be loaded via - // LoadComputationSnapshot. - rpc SnapshotComputation(SnapshotComputationRequest) - returns (SnapshotComputationResponse) { - } - - // Loads a computation from a captured snapshot. - rpc LoadComputationSnapshot(LoadComputationSnapshotRequest) - returns (LoadComputationSnapshotResponse) { - } } diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index aa3a6261e0117c4c2e5c745d6851142b22a62a07..2942edbf71f29304ebb240f0547808ae0af1ac87 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -12,22 +12,26 @@ package_group( ], ) +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_py", +) xla_proto_library( - name = "session_proto", - srcs = ["session.proto"], + name = "hlo_proto", + srcs = ["hlo.proto"], visibility = ["//visibility:public"], deps = ["//tensorflow/compiler/xla:xla_data_proto"], ) -xla_proto_library( - name = "hlo_proto", +tf_proto_library_py( + name = "hlo_proto", # bzl adds a _py suffix only to the OSS target. 
srcs = ["hlo.proto"], visibility = ["//visibility:public"], - deps = ["//tensorflow/compiler/xla:xla_data_proto"], ) xla_proto_library( @@ -265,6 +269,7 @@ cc_library( "dfs_hlo_visitor.cc", "hlo_computation.cc", "hlo_instruction.cc", + "hlo_instructions.cc", "hlo_module.cc", "hlo_opcode.cc", "hlo_sharding.cc", @@ -272,18 +277,21 @@ cc_library( hdrs = [ "dfs_hlo_visitor.h", "dfs_hlo_visitor_with_default.h", + "hlo_clone_context.h", "hlo_computation.h", + "hlo_domain_metadata.h", "hlo_instruction.h", + "hlo_instructions.h", "hlo_module.h", "hlo_opcode.h", "hlo_sharding.h", ], deps = [ + ":hlo_casting_utils", ":hlo_module_config", ":hlo_proto", ":hlo_reachability", ":name_uniquer", - ":versioned_computation_handle", "//tensorflow/compiler/xla:array", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", @@ -296,6 +304,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:human_readable_json", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", ], @@ -335,8 +344,8 @@ tf_cc_test( ":hlo", ":pattern_matcher", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:test", ], ) @@ -374,6 +383,7 @@ cc_library( deps = [ ":hlo", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -384,20 +394,9 @@ tf_cc_test( srcs = ["hlo_matchers_test.cc"], deps = [ ":hlo_matchers", + ":hlo_parser", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", - ], -) - -cc_library( - name = "versioned_computation_handle", - srcs = ["versioned_computation_handle.cc"], - hdrs = 
["versioned_computation_handle.h"], - deps = [ - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", ], ) @@ -406,12 +405,14 @@ tf_cc_test( srcs = ["hlo_instruction_test.cc"], deps = [ ":hlo", + ":hlo_parser", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -428,6 +429,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -532,45 +534,6 @@ tf_cc_test( ], ) -cc_library( - name = "user_computation", - srcs = ["user_computation.cc"], - hdrs = ["user_computation.h"], - deps = [ - ":hlo", - ":session_proto", - ":shape_inference", - ":versioned_computation_handle", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "user_computation_test", - srcs = ["user_computation_test.cc"], - deps = [ - ":hlo_matchers", - ":user_computation", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:xla_data_proto", - 
"//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:test", - ], -) - cc_library( name = "platform_util", srcs = ["platform_util.cc"], @@ -616,10 +579,8 @@ cc_library( ":allocation_tracker", ":backend", ":channel_tracker", - ":compilation_cache", ":compiler", ":computation_layout", - ":computation_tracker", ":device_memory_allocator", ":executable", ":execution_tracker", @@ -630,11 +591,8 @@ cc_library( ":hlo_module_config", ":hlo_proto_util", ":platform_util", - ":session_proto", ":source_map_util", ":transfer_manager", - ":user_computation", - ":versioned_computation_handle", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:service_interface", @@ -661,7 +619,6 @@ cc_library( ":backend", ":compiler", ":computation_layout", - ":computation_tracker", ":device_memory_allocator", ":executable", ":hlo", @@ -670,8 +627,6 @@ cc_library( ":platform_util", ":service", ":shaped_buffer", - ":user_computation", - ":versioned_computation_handle", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -695,7 +650,6 @@ cc_library( ":backend", ":compiler", ":computation_layout", - ":computation_tracker", ":platform_util", ":service", "//tensorflow/compiler/xla:status_macros", @@ -760,6 +714,23 @@ cc_library( ], ) +tf_cc_test( + name = "shaped_buffer_test", + srcs = ["shaped_buffer_test.cc"], + deps = [ + ":cpu_plugin", + ":device_memory_allocator", + ":platform_util", + ":shaped_buffer", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:ptr_util", + "//tensorflow/core:test", + ], +) + cc_library( name = "executable", srcs = ["executable.cc"], @@ -775,9 +746,7 @@ cc_library( ":hlo_graph_dumper", 
":hlo_proto", ":pool", - ":session_proto", ":shaped_buffer", - ":versioned_computation_handle", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -854,7 +823,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -874,34 +842,12 @@ cc_library( ], ) -cc_library( - name = "computation_tracker", - srcs = ["computation_tracker.cc"], - hdrs = ["computation_tracker.h"], - deps = [ - ":hlo", - ":hlo_module_config", - ":session_proto", - ":user_computation", - ":versioned_computation_handle", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - ], -) - cc_library( name = "channel_tracker", srcs = ["channel_tracker.cc"], hdrs = ["channel_tracker.h"], deps = [ ":hlo", - ":session_proto", - ":user_computation", - ":versioned_computation_handle", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -935,33 +881,6 @@ tf_cc_test( ], ) -cc_library( - name = "liveness_util", - srcs = ["liveness_util.cc"], - hdrs = ["liveness_util.h"], - deps = [ - ":hlo", - ":hlo_dataflow_analysis", - ":logical_buffer", - ":tuple_points_to_analysis", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - ], -) - -tf_cc_test( - name = "liveness_util_test", - srcs = ["liveness_util_test.cc"], - deps = [ - ":hlo", - ":liveness_util", - ":tuple_points_to_analysis", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - ], -) - cc_library( name = "buffer_liveness", srcs = [ @@ -973,7 +892,6 @@ cc_library( deps = [ ":hlo", ":hlo_ordering", - 
":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1010,6 +928,7 @@ cc_library( ], deps = [ ":buffer_liveness", + ":buffer_value_containers", ":heap_simulator", ":hlo", ":hlo_proto", @@ -1034,7 +953,6 @@ tf_cc_test( ":buffer_assignment", ":buffer_value", ":call_graph", - ":computation_tracker", ":copy_insertion", ":cpu_plugin", ":flatten_call_graph", @@ -1048,9 +966,9 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -1065,7 +983,6 @@ cc_library( ":hlo_dataflow_analysis", ":hlo_proto", ":hlo_value", - ":liveness_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1087,9 +1004,9 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1098,11 +1015,11 @@ cc_library( srcs = ["heap_simulator.cc"], hdrs = ["heap_simulator.h"], deps = [ + ":buffer_value", + ":buffer_value_containers", ":hlo", ":hlo_ordering", ":hlo_proto", - ":liveness_util", - ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -1118,7 +1035,7 @@ tf_cc_test( ":heap_simulator", ":hlo", ":hlo_ordering", - ":logical_buffer", + ":hlo_value", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:status_macros", @@ -1190,9 +1107,9 @@ tf_cc_test( 
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1225,9 +1142,22 @@ tf_cc_test( deps = [ ":hlo_matchers", ":instruction_fusion", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + +cc_library( + name = "multi_output_fusion", + srcs = ["multi_output_fusion.cc"], + hdrs = ["multi_output_fusion.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", ], ) @@ -1268,13 +1198,11 @@ cc_library( deps = [ ":hlo", ":hlo_pass", - ":hlo_query", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", ], @@ -1359,15 +1287,51 @@ tf_cc_test( ], ) +cc_library( + name = "batch_dot_simplification", + srcs = ["batch_dot_simplification.cc"], + hdrs = ["batch_dot_simplification.h"], + deps = [ + ":hlo", + ":hlo_creation_utils", + ":hlo_pass", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "batch_dot_simplification_test", + srcs = ["batch_dot_simplification_test.cc"], + deps = [ + ":batch_dot_simplification", + ":hlo", + ":hlo_matchers", + ":hlo_pass", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + 
"//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "gather_expander_test", srcs = ["gather_expander_test.cc"], deps = [ ":gather_expander", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:test_macros_header", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1673,14 +1637,11 @@ tf_cc_test( name = "hlo_cost_analysis_test", srcs = ["hlo_cost_analysis_test.cc"], deps = [ - ":computation_tracker", ":cpu_plugin", ":hlo", ":hlo_cost_analysis", ":local_service", ":service", - ":user_computation", - ":versioned_computation_handle", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", @@ -1719,8 +1680,10 @@ tf_cc_test( ":cpu_plugin", ":hlo_cost_analysis", ":hlo_execution_profile", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", ], ) @@ -1785,6 +1748,17 @@ cc_library( ], ) +cc_library( + name = "buffer_value_containers", + hdrs = ["buffer_value_containers.h"], + deps = [ + ":buffer_value", + ":logical_buffer", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + cc_library( name = "logical_buffer", srcs = ["logical_buffer.cc"], @@ -1859,6 +1833,44 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_liveness_analysis", + srcs = ["hlo_liveness_analysis.cc"], + hdrs = ["hlo_liveness_analysis.h"], + deps = [ + ":call_graph", + ":hlo", + 
":hlo_value", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "hlo_liveness_analysis_test", + srcs = ["hlo_liveness_analysis_test.cc"], + deps = [ + ":hlo", + ":hlo_liveness_analysis", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_buffer", srcs = ["hlo_buffer.cc"], @@ -1970,20 +1982,6 @@ tf_cc_test( ], ) -cc_library( - name = "compilation_cache", - srcs = ["compilation_cache.cc"], - hdrs = ["compilation_cache.h"], - deps = [ - ":executable", - ":hlo_module_config", - ":versioned_computation_handle", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - ], -) - cc_library( name = "layout_assignment", srcs = [ @@ -1995,10 +1993,12 @@ cc_library( deps = [ ":computation_layout", ":hlo", + ":hlo_dce", ":hlo_graph_dumper", ":hlo_pass", ":logical_buffer", ":tuple_points_to_analysis", + ":tuple_simplifier", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -2022,7 +2022,6 @@ cc_library( ":hlo_graph_dumper", ":hlo_ordering", ":hlo_pass", - ":liveness_util", ":logical_buffer", ":tuple_simplifier", "//tensorflow/compiler/xla:status_macros", @@ -2069,6 +2068,24 @@ cc_library( ], ) +cc_library( + name = "hlo_module_dce", + srcs = 
["hlo_module_dce.cc"], + hdrs = ["hlo_module_dce.h"], + deps = [ + ":hlo", + ":hlo_dce", + ":hlo_liveness_analysis", + ":hlo_pass", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + cc_library( name = "hlo_verifier", srcs = ["hlo_verifier.cc"], @@ -2111,7 +2128,6 @@ cc_library( ":hlo_dce", ":hlo_ordering", ":hlo_scheduling", - ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -2136,6 +2152,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -2159,6 +2176,27 @@ tf_cc_test( ], ) +tf_cc_test( + name = "hlo_module_dce_test", + srcs = ["hlo_module_dce_test.cc"], + deps = [ + ":hlo", + ":hlo_module_dce", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "layout_assignment_test", srcs = ["layout_assignment_test.cc"], @@ -2175,9 +2213,9 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", 
"//tensorflow/core:test", ], @@ -2225,6 +2263,7 @@ cc_library( hdrs = ["hlo_cse.h"], deps = [ ":hlo", + ":hlo_domain_map", ":hlo_pass", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -2247,6 +2286,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -2288,6 +2328,78 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_domain_map", + srcs = ["hlo_domain_map.cc"], + hdrs = ["hlo_domain_map.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_sharding_metadata", + srcs = ["hlo_sharding_metadata.cc"], + hdrs = [ + "hlo_sharding_metadata.h", + ], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_domain_isolator", + srcs = ["hlo_domain_isolator.cc"], + hdrs = ["hlo_domain_isolator.h"], + deps = [ + ":hlo", + ":hlo_graph_dumper", + ":hlo_pass", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + ], +) + +cc_library( + name = "hlo_domain_remover", + srcs = ["hlo_domain_remover.cc"], + hdrs = ["hlo_domain_remover.h"], + deps = [ + ":hlo", + ":hlo_domain_isolator", + ":hlo_domain_map", + ":hlo_graph_dumper", + ":hlo_pass", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "hlo_domain_test", + srcs = ["hlo_domain_test.cc"], + deps = [ + ":hlo", + ":hlo_domain_isolator", + ":hlo_domain_remover", + ":hlo_parser", + ":hlo_sharding_metadata", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + 
"//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_element_type_converter", srcs = ["hlo_element_type_converter.cc"], @@ -2316,8 +2428,14 @@ tf_cc_test( cc_library( name = "device_memory_allocator", - srcs = ["device_memory_allocator.cc"], - hdrs = ["device_memory_allocator.h"], + srcs = [ + "device_memory_allocator.cc", + "owning_device_memory.cc", + ], + hdrs = [ + "device_memory_allocator.h", + "owning_device_memory.h", + ], deps = [ "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -2352,6 +2470,24 @@ cc_library( ], ) +xla_test( + name = "elemental_ir_emitter_test", + srcs = ["elemental_ir_emitter_test.cc"], + backends = [ + "cpu", + "gpu", + ], + deps = [ + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "hlo_module_config", srcs = ["hlo_module_config.cc"], @@ -2406,7 +2542,6 @@ cc_library( name = "hlo_tfgraph_builder", srcs = ["hlo_tfgraph_builder.cc"], hdrs = ["hlo_tfgraph_builder.h"], - visibility = ["//tensorflow/compiler/xla/tools:__pkg__"], deps = [ ":hlo", "//tensorflow/compiler/xla:literal_util", @@ -2437,6 +2572,7 @@ cc_library( hdrs = ["hlo_graph_dumper.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_execution_profile", ":hlo_tfgraph_builder", "//tensorflow/compiler/xla:literal_util", @@ -2475,6 +2611,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", ], ) @@ -2493,10 +2630,10 @@ tf_cc_test( 
"//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -2529,7 +2666,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -2634,12 +2770,11 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", - "@com_google_absl//absl/memory", ], ) @@ -2671,8 +2806,8 @@ tf_cc_test( ":tuple_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -2695,9 +2830,10 @@ tf_cc_test( deps = [ ":while_util", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -2723,6 +2859,7 @@ tf_cc_test( ":hlo_matchers", 
":while_loop_invariant_code_motion", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], @@ -2749,8 +2886,8 @@ tf_cc_test( ":hlo_matchers", ":while_loop_constant_sinking", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:test", ], ) @@ -2781,3 +2918,97 @@ cc_library( "//tensorflow/core:lib", ], ) + +cc_library( + name = "indexed_array_analysis", + srcs = ["indexed_array_analysis.cc"], + hdrs = ["indexed_array_analysis.h"], + deps = [ + ":hlo", + ":hlo_evaluator", + ":hlo_pass", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "//tensorflow/core:ptr_util", + ], +) + +tf_cc_test( + name = "indexed_array_analysis_test", + srcs = ["indexed_array_analysis_test.cc"], + deps = [ + ":hlo_matchers", + ":indexed_array_analysis", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "hlo_parser", + srcs = ["hlo_parser.cc"], + hdrs = ["hlo_parser.h"], + deps = [ + ":hlo", + ":hlo_lexer", + ":hlo_sharding_metadata", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_cc_test( + name = "hlo_parser_test", + size = "small", + srcs = ["hlo_parser_test.cc"], + deps = [ + ":hlo_parser", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", # fixdeps: keep + ], +) + +cc_library( + name = 
"hlo_lexer", + srcs = ["hlo_lexer.cc"], + hdrs = [ + "hlo_lexer.h", + "hlo_token.h", + ], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + ], +) + +cc_library( + name = "hlo_casting_utils", + hdrs = ["hlo_casting_utils.h"], + deps = ["//tensorflow/core:lib"], +) + +tf_cc_test( + name = "hlo_casting_utils_test", + srcs = ["hlo_casting_utils_test.cc"], + deps = [ + ":hlo", + ":hlo_casting_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 4ec79a024463b5129cc8687235e673f9ea95959d..3b36939b8a6900f047bbec225aa232e0e805b5d1 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -92,26 +92,6 @@ bool ReshapeIsBitcast( valid_bitcast_callback(operand->shape(), reshape->shape()); } -// Adds a scalar computation to the module to enable optimizations with dot -// converting into reduction. 
-HloComputation* CreateScalarBinaryComputation(HloModule* module, - PrimitiveType primitive_type, - HloOpcode opcode) { - HloComputation::Builder b("scalar_computation"); - auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {}), "scalar_lhs")); - auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {}), "scalar_rhs")); - auto scalar_op = b.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), - opcode, scalar_lhs, scalar_rhs)); - HloComputation* scalar_computation = - module->AddEmbeddedComputation(b.Build(scalar_op)); - return scalar_computation; -} - -} // namespace - // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain // algebraic expressions to simplified forms. Note: This only supports // simplifications that simply look at the operands of an instruction. For the @@ -177,6 +157,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleSubtract(HloInstruction* sub) override; + Status HandleMap(HloInstruction* map) override; + Status HandleMaximum(HloInstruction* maximum) override; Status HandleMinimum(HloInstruction* minimum) override; @@ -220,8 +202,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { HloInstruction* zero = computation_->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloComputation* AddReduce_computation = CreateScalarBinaryComputation( - computation_->parent(), F32, HloOpcode::kAdd); + HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( shape, hlo, zero, {dim}, AddReduce_computation)); @@ -252,10 +233,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* operand, 
HloInstruction* max, HloInstruction* max_operand); - // A Reshape or Broadcast that feeds an element-wise operation with a unique - // non-scalar operand can sink to after the operation. - StatusOr TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( - HloInstruction* reshape_or_broadcast); + // A Broadcast that feeds an element-wise operation with a unique non-scalar + // operand can sink to after the operation. + StatusOr TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( + HloInstruction* broadcast); // Replaces the existing HLO instruction old_instruction, with // new_instruction, and marks the optimizer status as changed. @@ -293,6 +274,24 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { StatusOr OptimizeDotOfGather(HloInstruction* dot); + HloComputation* GetOrCreateScalarAddComputation() { + if (scalar_add_computation_) { + return scalar_add_computation_; + } + + HloComputation::Builder b("scalar_add_computation"); + Shape shape = ShapeUtil::MakeShape(F32, {}); + auto scalar_lhs = b.AddInstruction( + HloInstruction::CreateParameter(0, shape, "scalar_lhs")); + auto scalar_rhs = b.AddInstruction( + HloInstruction::CreateParameter(1, shape, "scalar_rhs")); + auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); + scalar_add_computation_ = + computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + return scalar_add_computation_; + } + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; @@ -311,8 +310,13 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Disable convolution simplification on platforms where it causes a slowdown. bool enable_conv_simplification_; + + // Cached computation for adding two scalar F32. 
+ HloComputation* scalar_add_computation_ = nullptr; }; +} // namespace + bool AlgebraicSimplifierVisitor::Run( HloComputation* computation, bool is_layout_sensitive, AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, @@ -501,13 +505,13 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( } static HloInstruction* BuildTupleConstant(HloComputation* computation, - const Literal& literal) { + const LiteralSlice& literal) { if (ShapeUtil::IsTuple(literal.shape())) { std::vector<HloInstruction*> elems; elems.reserve(ShapeUtil::TupleElementCount(literal.shape())); for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) { elems.push_back( - BuildTupleConstant(computation, LiteralView::Create(literal, {i}))); + BuildTupleConstant(computation, LiteralSlice(literal, {i}))); } return computation->AddInstruction(HloInstruction::CreateTuple(elems)); } else { @@ -1301,7 +1305,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { // broadcast after the unary element-wise operation. 
TF_ASSIGN_OR_RETURN( bool sink_succeeded, - TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(broadcast)); + TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(broadcast)); changed_ |= sink_succeeded; if (sink_succeeded) { return Status::OK(); @@ -1553,15 +1557,16 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { return Status::OK(); } -StatusOr<bool> AlgebraicSimplifierVisitor:: - TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( - HloInstruction* reshape_or_broadcast) { +StatusOr<bool> +AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( + HloInstruction* broadcast) { + TF_RET_CHECK(broadcast->opcode() == HloOpcode::kBroadcast); bool changed = false; - if (ShapeUtil::IsScalar(reshape_or_broadcast->shape())) { + if (ShapeUtil::IsScalar(broadcast->shape())) { return false; } - HloInstruction* operand = reshape_or_broadcast->mutable_operand(0); - for (HloInstruction* user : reshape_or_broadcast->users()) { + HloInstruction* operand = broadcast->mutable_operand(0); + for (HloInstruction* user : broadcast->users()) { if (user->user_count() == 0 && user != computation_->root_instruction()) { continue; } @@ -1579,55 +1584,50 @@ StatusOr<bool> AlgebraicSimplifierVisitor:: continue; } - int64 reshape_or_broadcast_operand_index = -1; // Find the unique non-scalar operand or continue if there isn't one. 
- int64 scalar_count = 0; - for (int64 i = 0; i < user->operand_count(); ++i) { - if (ShapeUtil::IsScalar(user->operand(i)->shape())) { - ++scalar_count; - } else { - reshape_or_broadcast_operand_index = i; + int64 scalar_broadcast_count = 0; + int64 broadcast_use_count = 0; + for (HloInstruction* user_operand : user->operands()) { + if (user_operand->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { + ++scalar_broadcast_count; + } else if (broadcast == user_operand) { + ++broadcast_use_count; } } - if (scalar_count != user->operand_count() - 1) { + if (scalar_broadcast_count + broadcast_use_count != user->operand_count()) { continue; } - VLOG(4) << "Sinking reshape or broadcast after user:"; - VLOG(4) << " old reshape/broadcast: " << reshape_or_broadcast->ToString(); + std::vector new_operands; + new_operands.reserve(user->operand_count()); + + for (HloInstruction* user_operand : user->operands()) { + if (user_operand->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { + new_operands.push_back( + computation_->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::ChangeElementType( + operand->shape(), user_operand->shape().element_type()), + user_operand->mutable_operand(0), {}))); + } else { + CHECK_EQ(broadcast, user_operand); + new_operands.push_back(operand); + } + } + VLOG(4) << "Sinking broadcast after user:"; + VLOG(4) << " old broadcast: " << broadcast->ToString(); VLOG(4) << " old user: " << user->ToString(); - CHECK_EQ(user->operand(reshape_or_broadcast_operand_index), - reshape_or_broadcast); - auto new_user_operands = user->operands(); - new_user_operands[reshape_or_broadcast_operand_index] = operand; - auto new_user = computation_->AddInstruction(user->CloneWithNewOperands( - ShapeUtil::MakeShapeWithLayout( - user->shape().element_type(), - AsInt64Slice(operand->shape().dimensions()), - LayoutUtil::MinorToMajor(operand->shape())), - new_user_operands)); 
+ HloInstruction* new_user = + computation_->AddInstruction(user->CloneWithNewOperands( + ShapeUtil::ChangeElementType(operand->shape(), + user->shape().element_type()), + new_operands)); VLOG(4) << " new user: " << new_user->ToString(); - HloInstruction* new_reshape_or_broadcast = nullptr; - if (reshape_or_broadcast->opcode() == HloOpcode::kReshape) { - new_reshape_or_broadcast = - computation_->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShapeWithLayout( - user->shape().element_type(), - AsInt64Slice(reshape_or_broadcast->shape().dimensions()), - LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())), - new_user)); - } else { - TF_RET_CHECK(reshape_or_broadcast->opcode() == HloOpcode::kBroadcast); - new_reshape_or_broadcast = - computation_->AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::MakeShapeWithLayout( - user->shape().element_type(), - AsInt64Slice(reshape_or_broadcast->shape().dimensions()), - LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())), - new_user, reshape_or_broadcast->dimensions())); - } - VLOG(4) << " new reshape/broadcast: " - << new_reshape_or_broadcast->ToString(); - TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_reshape_or_broadcast)); + HloInstruction* new_broadcast = + computation_->AddInstruction(HloInstruction::CreateBroadcast( + user->shape(), new_user, broadcast->dimensions())); + VLOG(4) << " new broadcast: " << new_broadcast->ToString(); + TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast)); changed = true; } return changed; @@ -1670,16 +1670,6 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { } } - // A Reshape that feeds a unary element-wise operation can sink the - // reshape after the unary element-wise operation. - TF_ASSIGN_OR_RETURN( - bool sink_succeeded, - TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(reshape)); - changed_ |= sink_succeeded; - if (sink_succeeded) { - return Status::OK(); - } - // Make this a bitcast if possible. 
if (is_layout_sensitive_ && ReshapeIsBitcast(reshape, valid_bitcast_callback_)) { @@ -1784,6 +1774,46 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { new_reduce_dimensions, function)); } + // If the reduction results in the same number of elements, then the only + // possible side effect would be a reshape. Since the init_value is an + // identity of the reduction function, we can therefore replace the reduce + // with a simple reshape, ignoring the reduction function completely. + if (ShapeUtil::ElementsIn(reduce->shape()) == + ShapeUtil::ElementsIn(arg->shape())) { + return ReplaceWithNewInstruction( + reduce, HloInstruction::CreateReshape(reduce->shape(), arg)); + } + + // If a reduce feeds a reduce with the same computation and initial value, + // they can be combined into a single reduce. + if (arg->opcode() == HloOpcode::kReduce && + init_value->Identical(*arg->operand(1)) && + *function == *arg->to_apply()) { + // Create a new reduce with the combined reduction dimensions of both + // reduces. + std::vector arg_dims = arg->dimensions(); + std::sort(arg_dims.begin(), arg_dims.end()); + std::vector reduce_dims = reduce->dimensions(); + std::sort(reduce_dims.begin(), reduce_dims.end()); + // Transform reduce_dims to the same rank as the operand of the operand. 
+ for (int64 arg_dim : arg_dims) { + for (int64& dim : reduce_dims) { + if (dim >= arg_dim) { + ++dim; + } + } + } + std::vector new_dimensions; + new_dimensions.reserve(arg->dimensions().size() + + reduce->dimensions().size()); + std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(), + reduce_dims.end(), std::back_inserter(new_dimensions)); + return ReplaceWithNewInstruction( + reduce, + HloInstruction::CreateReduce(reduce->shape(), arg->mutable_operand(0), + init_value, new_dimensions, function)); + } + // A reshape that collapses multiple dimensions into a dimension being // reduced can just reduce all of those dimensions instead of doing a // collapsing reshape before a reduction. @@ -1828,15 +1858,6 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { new_reduce_dimensions, function)); } } - if (ShapeUtil::ElementsIn(reduce->shape()) == - ShapeUtil::ElementsIn(arg->shape()) || - ShapeUtil::HasZeroElements(arg->shape())) { - auto reshape = computation_->AddInstruction( - HloInstruction::CreateReshape(reduce->shape(), arg)); - return ReplaceWithNewInstruction( - reduce, HloInstruction::CreateMap(reduce->shape(), - {init_value, reshape}, function)); - } return Status::OK(); } @@ -1856,7 +1877,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( return ReplaceWithNewInstruction( reduce_window, HloInstruction::CreateMap(reduce_window->shape(), - {operand, reduce_window->mutable_operand(1)}, + {reduce_window->mutable_operand(1), operand}, function)); } @@ -2186,6 +2207,39 @@ bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( return true; } +Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) { + auto* map_computation = map->to_apply(); + auto* map_root = map_computation->root_instruction(); + if (map_root->opcode() == HloOpcode::kParameter) { + ReplaceInstructionIfSameShape( + map, map->mutable_operand(map_root->parameter_number())); + return Status::OK(); + } + if (map_root->opcode() == 
HloOpcode::kConstant) { + if (!ShapeUtil::IsScalar(map_root->shape())) { + return Status::OK(); + } + auto clone = map_root->CloneWithNewOperands(map_root->shape(), {}); + if (ShapeUtil::IsScalar(map->shape())) { + return ReplaceWithNewInstruction(map, std::move(clone)); + } + return ReplaceWithNewInstruction( + map, + HloInstruction::CreateBroadcast( + map->shape(), computation_->AddInstruction(std::move(clone)), {})); + } + std::vector new_operands; + for (auto* root_operand : map_root->operands()) { + if (root_operand->opcode() != HloOpcode::kParameter) { + return Status::OK(); + } + new_operands.push_back( + map->mutable_operand(root_operand->parameter_number())); + } + auto clone = map_root->CloneWithNewOperands(map->shape(), new_operands); + return ReplaceWithNewInstruction(map, std::move(clone)); +} + Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) { // Match the following tree: // min_operand operand diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 4e082877c776c35bab499c805fef7632765a3ee1..2605b0488cb7c6850746df94c4ab05d6b5d35de5 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -74,6 +74,44 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { EXPECT_EQ(root, param0); } +// Test that Reduce(Reduce(A)) -> Reduce(A) +TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { + HloComputation::Builder builder(TestName()); + // Create add computation. 
+ HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module().AddEmbeddedComputation(builder.Build()); + } + Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r4f32, "param")); + std::vector dims0({0}); + Shape r3f32 = ShapeUtil::MakeShape(F32, {5, 6, 7}); + HloInstruction* reduce0 = builder.AddInstruction( + HloInstruction::CreateReduce(r3f32, param, zero, dims0, add_computation)); + std::vector dims1({1, 2}); + Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); + builder.AddInstruction(HloInstruction::CreateReduce(r1f32, reduce0, zero, + dims1, add_computation)); + module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Reduce(param, zero)); + EXPECT_EQ(root->dimensions(), std::vector({0, 2, 3})); +} + // Test that Const + A is canonicalized to A + Const. TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -143,6 +181,39 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { EXPECT_EQ(root, param0); } +TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { + HloComputation::Builder builder(TestName()); + // Create add computation. 
+ HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module().AddEmbeddedComputation(builder.Build()); + } + Shape r2f32 = ShapeUtil::MakeShape(F32, {32, 1}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + builder.AddInstruction( + HloInstruction::CreateMap(r2f32, {param0, zero}, add_computation)); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kMap); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(param0, zero)); +} + TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); HloComputation::Builder builder(TestName()); @@ -1318,32 +1389,6 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape)); } -TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { - HloComputation::Builder builder(TestName()); - HloInstruction* param = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), "param")); - HloInstruction* movable_reshape = - builder.AddInstruction(HloInstruction::CreateReshape( - 
ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), param)); - HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), - HloOpcode::kMaximum, movable_reshape, zero)); - auto computation = module().AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Maximum(op::Reshape(param), zero)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); - - simplifier.Run(&module()).ValueOrDie(); - EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Maximum(param, zero))); -} - // Regression test for a bug in the reshape sinking transformation, where // moving a reshape to a scalar led to a crash. TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { @@ -1707,7 +1752,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1752,7 +1797,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); EXPECT_TRUE(has_negative_padding(pad)); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); EXPECT_FALSE( @@ -1774,7 +1819,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1797,7 +1842,7 @@ TEST_F(AlgebraicSimplifierTest, 
RemoveNoopSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1925,7 +1970,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter, window, dnums)); - auto module = CreateNewModule(); + // TODO(b/80488902): verify this module. + auto module = HloTestBase::CreateNewModule(); auto* computation = module->AddEntryComputation(b.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, @@ -2053,7 +2099,7 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2083,7 +2129,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2114,7 +2160,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2144,7 +2190,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); 
+ EXPECT_FALSE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); @@ -2177,7 +2223,7 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), @@ -2193,10 +2239,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { HloInstruction::CreateParameter(0, r0f32, "scalar_param")); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); - HloInstruction* broadcast = - builder.AddInstruction(HloInstruction::CreateBroadcast( - broadcast_shape, scalar_param, - AsInt64Slice(broadcast_shape.dimensions()))); + HloInstruction* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(broadcast_shape, scalar_param, {})); Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3}); HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice( @@ -2212,10 +2256,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); // Running simplification again should not result in any further changes. 
- ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(scalar_param)); @@ -2230,10 +2274,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6}); - HloInstruction* broadcast = - builder.AddInstruction(HloInstruction::CreateBroadcast( - broadcast_shape, forty_two, - AsInt64Slice(broadcast_shape.dimensions()))); + HloInstruction* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(broadcast_shape, forty_two, {})); HloInstruction* transpose = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -2252,7 +2294,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(forty_two)); @@ -2261,7 +2303,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { - auto module = CreateNewModule(); + // TODO(b/80488902): verify this module. + auto module = HloTestBase::CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2342,7 +2385,8 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to // ReduceWindow(Convert(op), x). TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { - auto module = CreateNewModule(); + // TODO(b/80488902): verify this module. 
+ auto module = HloTestBase::CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2437,7 +2481,7 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(a, root); diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index cf1231bcce4d004284b71a49063e3e470a9eb93f..95b4cb6d2e694063b648b264bd2454ae0a5469ff 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -101,7 +101,7 @@ StatusOr AllocationTracker::RegisterInternal( return result; } -tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) { +Status AllocationTracker::Unregister(const GlobalDataHandle& data) { tensorflow::mutex_lock lock(mutex_); VLOG(2) << "Unregister(" << "handle: " << data.handle() << ")"; @@ -130,7 +130,7 @@ tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) { for (auto& shaped_buffer : it->second) { shaped_buffer.reset(); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr> AllocationTracker::DeconstructTuple( @@ -220,8 +220,10 @@ void AllocationTracker::AddAllocationOrIncrementRefCount( AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal]; auto it = allocation_map.find(device_memory.opaque()); if (it == allocation_map.end()) { - allocation_map[device_memory.opaque()] = {device_memory, device_ordinal, - /*ref_count=*/1}; + allocation_map[device_memory.opaque()] = { + OwningDeviceMemory(device_memory, device_ordinal, + backend_->memory_allocator()), + /*ref_count=*/1}; } else { it->second.ref_count++; } @@ -235,13 +237,12 @@ Status 
AllocationTracker::DecrementRefCount(se::DeviceMemoryBase device_memory, Allocation& allocation = it->second; TF_RET_CHECK(allocation.ref_count >= 1); if (allocation.ref_count == 1) { - TF_RETURN_IF_ERROR(backend_->memory_allocator()->Deallocate( - device_ordinal, &device_memory)); + allocation.device_memory.Free(); allocation_map.erase(it); } else { allocation.ref_count--; } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index 1174fa641c06ae053bcc652416bfbc30cabc777c..a7d8927cf7e90d764ff8046df16c71922b11478e 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -76,10 +76,7 @@ class AllocationTracker { // Data structure encapsulating single memory allocation on the device. struct Allocation { // The pointer to this allocation. - se::DeviceMemoryBase device_memory; - - // The device that the memory is allocated on. - int device_ordinal; + OwningDeviceMemory device_memory; // This is the number of times this memory allocation is referred to by // registered data handles. @@ -126,7 +123,10 @@ class AllocationTracker { int64 next_handle_ GUARDED_BY(mutex_); // A map from device ordinal to AllocationMap. - tensorflow::gtl::FlatMap opaque_to_allocation_map_ + // + // This is not a TF FlatMap because (currently) FlatMap (and therefore + // AllocationMap) is not movable. 
+ std::unordered_map opaque_to_allocation_map_ GUARDED_BY(mutex_); // A map from data handle to a vector of shaped buffers that represent the diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc new file mode 100644 index 0000000000000000000000000000000000000000..2099916509acdbc2680cc2b5bd405e96f2f7bfb8 --- /dev/null +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -0,0 +1,99 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/batch_dot_simplification.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" + +namespace xla { +StatusOr +BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( + HloInstruction* batch_dot) { + const DotDimensionNumbers& dim_numbers = batch_dot->dot_dimension_numbers(); + HloInstruction *lhs = batch_dot->mutable_operand(0), + *rhs = batch_dot->mutable_operand(1); + const Shape& lhs_shape = lhs->shape(); + + std::vector degenerate_dims; + for (int64 batch_dim : dim_numbers.lhs_batch_dimensions()) { + if (lhs_shape.dimensions(batch_dim) == 1) { + degenerate_dims.push_back(batch_dim); + } + } + + if (degenerate_dims.empty()) { + return false; + } + + TF_ASSIGN_OR_RETURN(HloInstruction * new_lhs, + ElideDegenerateDims(lhs, degenerate_dims)); + TF_ASSIGN_OR_RETURN(HloInstruction * new_rhs, + ElideDegenerateDims(rhs, degenerate_dims)); + + DotDimensionNumbers new_dim_numbers = dim_numbers; + new_dim_numbers.clear_lhs_batch_dimensions(); + new_dim_numbers.clear_rhs_batch_dimensions(); + + for (int64 i = 0, e = dim_numbers.lhs_batch_dimensions_size() - + degenerate_dims.size(); + i < e; i++) { + new_dim_numbers.add_lhs_batch_dimensions(i); + new_dim_numbers.add_rhs_batch_dimensions(i); + } + + new_dim_numbers.set_lhs_contracting_dimensions( + 0, + new_dim_numbers.lhs_contracting_dimensions(0) - degenerate_dims.size()); + new_dim_numbers.set_rhs_contracting_dimensions( + 0, + new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size()); + + TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, + MakeDotHlo(new_lhs, new_rhs, new_dim_numbers)); + + TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped, + MakeReshapeHlo(batch_dot->shape(), new_dot)); + + VLOG(2) << "Replaced " << batch_dot->ToString() << " with " + << new_dot->ToString(); + + 
TF_RETURN_IF_ERROR( + batch_dot->parent()->ReplaceInstruction(batch_dot, new_dot_reshaped)); + + return true; +} + +tensorflow::StringPiece BatchDotSimplification::name() const { + return "batch-dot-simplification"; +} + +StatusOr BatchDotSimplification::Run(HloModule* module) { + bool changed = false; + std::vector dot_instrs; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + c_copy_if(computation->instructions(), std::back_inserter(dot_instrs), + [](HloInstruction* instr) { + return instr->opcode() == HloOpcode::kDot; + }); + } + for (HloInstruction* dot_instr : dot_instrs) { + TF_ASSIGN_OR_RETURN(bool elided_batch_dim_from_one, + ElideDegenerateBatchDimensionFromBatchDot(dot_instr)); + changed |= elided_batch_dim_from_one; + } + return changed; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h new file mode 100644 index 0000000000000000000000000000000000000000..c0ca8d8ebac1a3b218e7bd4d6db02b69cfb6916f --- /dev/null +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +// Simplifies batch dot operations. +// +// Normally these would live in the algebraic simplifier, but we want to run +// this to fixpoint (this pass reaches fixed point in one execution) before we +// run the DotDecomposer. +class BatchDotSimplification : public HloPassInterface { + public: + StatusOr Run(HloModule* module) override; + tensorflow::StringPiece name() const override; + + private: + StatusOr ElideDegenerateBatchDimensionFromBatchDot( + HloInstruction* batch_dot); +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..38f1a5d3a645f98220ec445bb9bbdf2b9b842109 --- /dev/null +++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc @@ -0,0 +1,168 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/batch_dot_simplification.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class BatchDotSimplificationTest : public HloVerifiedTestBase {}; + +TEST_F(BatchDotSimplificationTest, + ElideSingleDegenerateBatchDotDim_VectorVector) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[1,3] parameter(0) + b = f32[1,3] parameter(1) + ROOT dot = f32[1] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/0))); +} + +TEST_F(BatchDotSimplificationTest, + ElideSingleDegenerateBatchDotDim_MatrixVector) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[1,9,3] parameter(0) + b = f32[1,3] parameter(1) + ROOT dot = f32[1,9] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0))); +} + +TEST_F(BatchDotSimplificationTest, + 
ElideSingleDegenerateBatchDotDim_MatrixMatrix) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[1,9,3] parameter(0) + b = f32[1,3,7] parameter(1) + ROOT dot = f32[1,9,7] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0))); +} + +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDims_VectorVector) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[9,1,7,1,3] parameter(0) + b = f32[9,1,7,1,3] parameter(1) + ROOT dot = f32[9,1,7,1] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={4}, rhs_contracting_dims={4} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/2, /*rhs_contracting_dim=*/2))); +} + +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDims_VectorMatrix) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[9,1,7,1,3] parameter(0) + b = f32[9,1,7,1,20,3] parameter(1) + ROOT dot = f32[9,1,7,1,20] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={4}, rhs_contracting_dims={5} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + 
op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/2, /*rhs_contracting_dim=*/3))); +} + +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDims_MatrixMatrix) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[9,1,7,1,19,3] parameter(0) + b = f32[9,1,7,1,3,20] parameter(1) + ROOT dot = f32[9,1,7,1,19,20] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={5}, rhs_contracting_dims={4} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/2))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 38086bd7e121847be6b6b69415cfe87814e7fc24..ec13fadbc75e2315d1d6ef72e24a0faca0c7de40 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -15,35 +15,32 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/batchnorm_expander.h" -#include #include -#include -#include #include #include #include -#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { +namespace { + // BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm // operations into smaller operations. class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { @@ -61,8 +58,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { // Runs the visitor on a computation. static bool Run(HloComputation* computation, bool rewrite_training_op, - bool rewrite_inference_op, bool rewrite_grad_op, - bool use_fusion); + bool rewrite_inference_op, bool rewrite_grad_op); // Returns whether any batch norm ops were rewritten. 
const bool changed() const { return changed_; } @@ -73,37 +69,55 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { explicit BatchNormExpanderVisitor(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, - bool rewrite_grad_op, bool use_fusion) + bool rewrite_grad_op) : computation_(computation), rewrite_training_op_(rewrite_training_op), rewrite_inference_op_(rewrite_inference_op), - rewrite_grad_op_(rewrite_grad_op), - use_fusion_(use_fusion) {} - - HloComputation* GetScalarBinaryComputation(PrimitiveType primitive_type, - HloOpcode opcode) { - HloComputation::Builder b("scalar_computation"); - auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(primitive_type, {}), "scalar_lhs")); - auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(primitive_type, {}), "scalar_rhs")); - auto scalar_op = b.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), - opcode, scalar_lhs, scalar_rhs)); + rewrite_grad_op_(rewrite_grad_op) {} + + HloComputation* GetOrCreateScalarAddComputation( + PrimitiveType primitive_type) { + HloComputation::Builder b("scalar_add_computation"); + Shape shape = ShapeUtil::MakeShape(primitive_type, {}); + auto scalar_lhs = b.AddInstruction( + HloInstruction::CreateParameter(0, shape, "scalar_lhs")); + auto scalar_rhs = b.AddInstruction( + HloInstruction::CreateParameter(1, shape, "scalar_rhs")); + auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); } - // Current HloComputation instance the BatchNormExpander is - // traversing. 
- HloComputation* computation_; - - bool rewrite_training_op_; - bool rewrite_inference_op_; - bool rewrite_grad_op_; - bool use_fusion_; + std::unique_ptr Rsqrt( + HloInstruction* operand, + const std::function)>& + add_instruction) { + HloInstruction* exponent = add_instruction(HloInstruction::CreateBroadcast( + operand->shape(), + add_instruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(operand->shape().element_type(), {}), + add_instruction(HloInstruction::CreateConstant( + Literal::CreateR0(-0.5f))))), + {})); + return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kPower, + operand, exponent); + } - // Whether rewrite has occurred. - bool changed_ = false; + std::unique_ptr Mean( + int64 element_count, HloInstruction* operand, + const std::function)>& + add_instruction) { + HloInstruction* elem_count_recip = + add_instruction(HloInstruction::CreateBroadcast( + operand->shape(), + add_instruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(operand->shape().element_type(), {}), + add_instruction(HloInstruction::CreateConstant( + Literal::CreateR0(1.0 / element_count))))), + {})); + return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply, + operand, elem_count_recip); + } // Replaces the existing HLO instruction old_instruction, with // new_instruction, and marks the optimizer status as changed. @@ -127,18 +141,29 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { changed_ = true; return Status::OK(); } + // Current HloComputation instance the BatchNormExpander is + // traversing. + HloComputation* computation_; + + bool rewrite_training_op_; + bool rewrite_inference_op_; + bool rewrite_grad_op_; + + // Whether rewrite has occurred. 
+ bool changed_ = false; }; +} // namespace + bool BatchNormExpanderVisitor::Run(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, - bool rewrite_grad_op, bool use_fusion) { + bool rewrite_grad_op) { BatchNormExpanderVisitor visitor( computation, /*rewrite_training_op=*/rewrite_training_op, /*rewrite_inference_op=*/rewrite_inference_op, - /*rewrite_grad_op=*/rewrite_grad_op, - /*use_fusion=*/use_fusion); + /*rewrite_grad_op=*/rewrite_grad_op); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -156,6 +181,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( added_instructions.push_back(added_inst); return added_inst; }; + auto add_binary = [&](const Shape& shape, const HloOpcode opcode, + HloInstruction* a, HloInstruction* b) { + return add(HloInstruction::CreateBinary(shape, opcode, a, b)); + }; int64 instruction_count_before = computation_->instruction_count(); // Expand batch norm training into smaller HLO ops. @@ -165,12 +194,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( int64 feature_index = batch_norm->feature_index(); const int64 feature_count = operand_shape.dimensions(feature_index); const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape); - auto elements_per_feature_literal = - Literal::CreateR0(size_in_elements / feature_count); - TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal->Convert(ptype)); - auto elements_per_feature = add( - HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + int64 elements_per_feature_int64 = size_in_elements / feature_count; HloInstruction* scale = batch_norm->mutable_operand(1); HloInstruction* offset = batch_norm->mutable_operand(2); @@ -182,8 +206,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = - 
add(HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon = add(HloInstruction::CreateBroadcast( + operand_shape, + add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {})); std::vector dimensions_without_feature; for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { @@ -199,11 +224,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); HloComputation* add_reduce_computation = - GetScalarBinaryComputation(ptype, HloOpcode::kAdd); + GetOrCreateScalarAddComputation(ptype); // X^2. - auto operand_squared = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, operand, operand)); + auto operand_squared = + add_binary(operand_shape, HloOpcode::kMultiply, operand, operand); // Sum[X]. auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero, dimensions_without_feature, @@ -214,71 +239,48 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( feature_shape, operand_squared, zero, dimensions_without_feature, add_reduce_computation)); - // Fuse two parallel reduces together to improve performance. - if (use_fusion_ && !batch_norm->has_sharding()) { - auto tuple = add(HloInstruction::CreateTuple({sum, squared_sum})); - - auto fused = computation_->CreateFusionInstruction( - {tuple, sum, squared_sum, operand_squared}, - HloInstruction::FusionKind::kInput); - - sum = add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - - squared_sum = - add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); - } - // E[X]. - auto mean = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kDivide, sum, elements_per_feature)); + auto mean = add(Mean(elements_per_feature_int64, sum, add)); auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); // E[X^2]. 
- auto square_mean = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kDivide, squared_sum, elements_per_feature)); + auto square_mean = add(Mean(elements_per_feature_int64, squared_sum, add)); // E^2[X]. - auto mean_square = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kMultiply, mean, mean)); + auto mean_square = + add_binary(feature_shape, HloOpcode::kMultiply, mean, mean); // Var[X]. - auto var = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kSubtract, square_mean, mean_square)); + auto var = + add_binary(feature_shape, HloOpcode::kSubtract, square_mean, mean_square); auto var_broadcasted = add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); - - auto neg_half_literal = Literal::CreateR0(-0.5f); - TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = - add(HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto var_add_epsilon = + add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add)); // X - E[X]. - auto operand_minus_mean = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract, + operand, mean_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = add( - HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, - operand_minus_mean, rsqrt_var_add_epsilon)); + auto normalized = add_binary(operand_shape, HloOpcode::kMultiply, + operand_minus_mean, rsqrt_var_add_epsilon); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. 
- auto scaled_normalized = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply, + normalized, scale_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. - auto shifted_normalized = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted)); + auto shifted_normalized = add_binary(operand_shape, HloOpcode::kAdd, + scaled_normalized, offset_broadcasted); auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var}); @@ -320,8 +322,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast( + operand_shape, + computation_->AddInstruction( + HloInstruction::CreateConstant(std::move(epsilon_literal))), + {})); std::vector dimensions_without_feature; @@ -338,6 +343,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( added_instructions.push_back(added_inst); return added_inst; }; + auto add_binary = [&](const Shape& shape, const HloOpcode opcode, + HloInstruction* a, HloInstruction* b) { + return add(HloInstruction::CreateBinary(shape, opcode, a, b)); + }; int64 instruction_count_before = computation_->instruction_count(); auto scale_broadcasted = add( @@ -353,30 +362,23 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. 
- auto var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); - - auto neg_half_literal = Literal::CreateR0(-0.5f); - TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = - add(HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto var_add_epsilon = + add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add)); // X - E[X]. - auto operand_minus_mean = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract, + operand, mean_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = add( - HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, - operand_minus_mean, rsqrt_var_add_epsilon)); + auto normalized = add_binary(operand_shape, HloOpcode::kMultiply, + operand_minus_mean, rsqrt_var_add_epsilon); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply, + normalized, scale_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. 
auto shifted_normalized = HloInstruction::CreateBinary( @@ -424,6 +426,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( added_instructions.push_back(added_inst); return added_inst; }; + auto add_binary = [&](const Shape& shape, const HloOpcode opcode, + HloInstruction* a, HloInstruction* b) { + return add(HloInstruction::CreateBinary(shape, opcode, a, b)); + }; int64 instruction_count_before = computation_->instruction_count(); HloInstruction* activation = batch_norm->mutable_operand(0); @@ -439,26 +445,20 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape); const int64 feature_count = activation_shape.dimensions(feature_index); - auto elements_per_feature_literal = - Literal::CreateR0(size_in_elements / feature_count); - TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal->Convert(ptype)); - auto elements_per_feature = add( - HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + const int64 elements_per_feature_int64 = size_in_elements / feature_count; auto zero_literal = Literal::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); - auto neg_half_literal = Literal::CreateR0(-0.5f); - TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = - add(HloInstruction::CreateConstant(std::move(neg_half_literal))); - auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = + auto epsilon_scalar = add(HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon_activation = add( + HloInstruction::CreateBroadcast(activation_shape, epsilon_scalar, {})); + auto epsilon_feature = + add(HloInstruction::CreateBroadcast(feature_shape, epsilon_scalar, {})); std::vector dimensions_without_feature; @@ 
-478,29 +478,26 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index})); // rsqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon_broadcasted = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kPower, - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, - variance_broadcasted, epsilon)), - neg_half)); - - auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kPower, - add(HloInstruction::CreateBinary(feature_shape, HloOpcode::kAdd, variance, - epsilon)), - neg_half)); + auto rsqrt_var_add_epsilon_broadcasted = + add(Rsqrt(add_binary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon_activation), + add)); + + auto rsqrt_var_add_epsilon = add(Rsqrt( + add_binary(feature_shape, HloOpcode::kAdd, variance, epsilon_feature), + add)); // X - E[X]. - auto activation_minus_mean = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted)); + auto activation_minus_mean = add_binary( + activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted); // Grad[Y] * (X - E[X]). auto grad_output_times_activiation_minus_mean = - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, activation_minus_mean)); + add_binary(activation_shape, HloOpcode::kMultiply, grad_output, + activation_minus_mean); HloComputation* add_reduce_computation = - GetScalarBinaryComputation(ptype, HloOpcode::kAdd); + GetOrCreateScalarAddComputation(ptype); // sum(Grad[Y] * (X - E[X])). 
auto sum_grad_output_times_activiation_minus_mean = @@ -513,25 +510,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( feature_shape, grad_output, zero, dimensions_without_feature, add_reduce_computation)); - if (use_fusion_ && !batch_norm->has_sharding()) { - auto tuple = add(HloInstruction::CreateTuple( - {sum_grad_output_times_activiation_minus_mean, grad_beta})); - - auto fused = computation_->CreateFusionInstruction( - {tuple, sum_grad_output_times_activiation_minus_mean, grad_beta}, - HloInstruction::FusionKind::kInput); - - sum_grad_output_times_activiation_minus_mean = - add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - - grad_beta = - add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); - } - // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]). - auto grad_scale = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kMultiply, - sum_grad_output_times_activiation_minus_mean, rsqrt_var_add_epsilon)); + auto grad_scale = add_binary(feature_shape, HloOpcode::kMultiply, + sum_grad_output_times_activiation_minus_mean, + rsqrt_var_add_epsilon); // I2 = Sum(Grad[Y]) auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta, @@ -543,39 +525,40 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( {feature_index})); // I4 = (X - E[X]) * I3 - auto i4 = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, i3, activation_minus_mean)); + auto i4 = add_binary(activation_shape, HloOpcode::kMultiply, i3, + activation_minus_mean); // I5 = I4 / (Var[X] + epsilon) - auto i5 = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, i4, - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, - variance_broadcasted, epsilon)))); + auto i5 = add_binary(activation_shape, HloOpcode::kDivide, i4, + add_binary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon_activation)); // scale * rsqrt[Var[X] + epsilon] * 1/N - auto 
scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, scale_broadcasted, - rsqrt_var_add_epsilon_broadcasted)); + auto scale_times_rsqrt_var_add_epsilon = + add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted, + rsqrt_var_add_epsilon_broadcasted); - scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, scale_times_rsqrt_var_add_epsilon, - elements_per_feature)); + scale_times_rsqrt_var_add_epsilon = add( + Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add)); - auto i1 = - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, elements_per_feature)); + auto elements_per_feature_literal = + Literal::CreateR0(elements_per_feature_int64); + TF_ASSIGN_OR_RETURN(elements_per_feature_literal, + elements_per_feature_literal->Convert(ptype)); + auto elements_per_feature = add( + HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output, + add(HloInstruction::CreateBroadcast( + activation_shape, elements_per_feature, {}))); // I6 = I1 - I2 - I5 - auto i6 = add(HloInstruction::CreateBinary( + auto i6 = add_binary( activation_shape, HloOpcode::kSubtract, - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract, - i1, i2)), - i5)); + add_binary(activation_shape, HloOpcode::kSubtract, i1, i2), i5); // Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6. 
- auto grad_activation = - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - scale_times_rsqrt_var_add_epsilon, i6)); + auto grad_activation = add_binary(activation_shape, HloOpcode::kMultiply, + scale_times_rsqrt_var_add_epsilon, i6); auto tuple = HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta}); if (batch_norm->has_sharding()) { @@ -604,8 +587,8 @@ StatusOr BatchNormExpander::Run(HloModule* module) { bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_, - rewrite_inference_op_, rewrite_grad_op_, - use_fusion_)) { + rewrite_inference_op_, + rewrite_grad_op_)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h index 4ad987085da91684bb7891070afeefd19be4138f..7ae202c583516443a6263403fb5460d1adbabd97 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.h +++ b/tensorflow/compiler/xla/service/batchnorm_expander.h @@ -31,11 +31,10 @@ class BatchNormExpander : public HloPassInterface { // When use_fusion is set, a multi-output fusion node is created. 
BatchNormExpander(bool rewrite_training_op = false, bool rewrite_inference_op = false, - bool rewrite_grad_op = false, bool use_fusion = true) + bool rewrite_grad_op = false) : rewrite_training_op_(rewrite_training_op), rewrite_inference_op_(rewrite_inference_op), - rewrite_grad_op_(rewrite_grad_op), - use_fusion_(use_fusion) {} + rewrite_grad_op_(rewrite_grad_op) {} ~BatchNormExpander() = default; tensorflow::StringPiece name() const override { return "batchnorm_expander"; } @@ -47,7 +46,6 @@ class BatchNormExpander : public HloPassInterface { bool rewrite_training_op_; bool rewrite_inference_op_; bool rewrite_grad_op_; - bool use_fusion_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index 08d0152e3cfcfcb7ae1e85f72c2f7dc856f5e8b3..1b8b2d204503576c3fcb02f6d5b37f2db45e1768 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -182,15 +182,26 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum( HloInstruction* crs) { - if (!ShapeUtil::IsTuple(crs->shape()) || - !bfloat16_support_->SupportsMixedPrecisions(*crs)) { - return DefaultAction(crs); - } - // First use DefaultAction() to handle the operands. It can't handle // tuple-shaped output. TF_RETURN_IF_ERROR(DefaultAction(crs)); + if (!bfloat16_support_->SupportsMixedPrecisions(*crs)) { + return Status::OK(); + } + + // If the output is not a tuple, we don't need special handling. + if (!ShapeUtil::IsTuple(crs->shape())) { + return Status::OK(); + } + + // If crs is the root instruction, we should keep its original output type. + // The root instruction implicitly has a use from being the result of the + // computation, and the code below does not take this use into account. 
+ if (crs == computation_->root_instruction()) { + return Status::OK(); + } + // Then do per-tuple-element handling on the output. std::vector> per_tuple_element_gtes( crs->operand_count()); diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 28e71c2054f59ba4d5d096bf7d898161877bb42f..7fd1e733e96da95cf43d9861af6d48a1850051c8 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -211,6 +211,17 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { auto builder = HloComputation::Builder(TestName()); + + auto module = CreateNewModule(); + HloComputation::Builder sum_builder("add"); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y)); + HloComputation* sum = module->AddEmbeddedComputation(sum_builder.Build()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); @@ -223,7 +234,8 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( - ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b})); + ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, + sum)); HloInstruction* gte_a = builder.AddInstruction( HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); HloInstruction* gte_b = builder.AddInstruction( @@ -233,7 +245,6 @@ TEST_F(BFloat16ConversionFoldingTest, 
FoldCrossReplicaSumTupleOutput) { HloInstruction* tuple = builder.AddInstruction( HloInstruction::CreateTuple({gte_a, convert_gte_b})); - auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_TRUE(FoldConversions(module.get())); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 1afaefd9df9c5771fb9e134ae9050f3abb00ea4a..9926661dd30600b2bf20e7f137aa50d9fbfd7c82 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -228,6 +228,17 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { + auto module = CreateNewModule(); + HloComputation::Builder sum_builder("sum"); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y)); + HloComputation* reduction = + module->AddEmbeddedComputation(sum_builder.Build()); + auto builder = HloComputation::Builder(TestName()); Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); @@ -239,11 +250,11 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( - ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b})); + ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, + reduction)); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); - auto module = 
CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_TRUE(Normalize(module.get())); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 313910a861f7f4c0d1d60b738caef40e76cc4260..5e1499ee6b6ef397f95f7ed29e808d530777bd07 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -149,12 +149,12 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { EXPECT_TRUE(OutputsBF16(dot->operand(1))); EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( dot->operand(0)->literal(), - *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_a))); - LiteralTestUtil::ExpectEqual( + *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)))); + EXPECT_TRUE(LiteralTestUtil::Equal( dot->operand(1)->literal(), - *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_b))); + *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)))); } // Tests that BF16 can be propagated through nested tuples. diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 94ccfedf6289b4af1accebd358671c3e2bc10ba7..682c3865797c85eedf3949738f3372857f146c0e 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -134,6 +135,7 @@ Status GatherComputationsByAllocationType( worklist.push_back(std::make_pair(subcomputation, false)); // Not thread local. break; + case HloOpcode::kCrossReplicaSum: case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: @@ -699,7 +701,7 @@ BufferAssignmentProto BufferAssignment::ToProto() const { BufferAssignmentProto::BufferAlias* proto_alias = proto.add_buffer_aliases(); LogicalBufferProto::Location proto_alias_location = - LogicalBuffer::ToLocationProto(*alias.instruction(), alias.index()); + BufferValue::ToLocationProto(*alias.instruction(), alias.index()); proto_alias->set_source_buffer_id(buffer.id()); proto_alias->mutable_location()->Swap(&proto_alias_location); } @@ -1083,7 +1085,9 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( VLOG(2) << "Simulating heap for color " << color; int64 alignment = assignment->color_alignment_(color); HeapSimulator::Options options; - options.buffers_to_assign = &single_colored_set.second; + BufferValueFlatSet buffer_value_set = + ToBufferValueFlatSet(single_colored_set.second); + options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(MakeUnique( @@ -1111,7 +1115,9 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( VLOG(2) << "Simulating heap for color " << color; int64 alignment = assignment->color_alignment_(color); HeapSimulator::Options options; - options.buffers_to_assign = &single_colored_set.second; + BufferValueFlatSet buffer_value_set = + ToBufferValueFlatSet(single_colored_set.second); + options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( 
const HeapSimulator::Result result, HeapSimulator::Run(MakeUnique( @@ -1224,7 +1230,10 @@ void BufferAssigner::AssignBuffersFromHeapSimulator( BufferAllocation* allocation = assignment->NewEmptyAllocation( result.heap_size, /*is_thread_local=*/false, /*is_reusable=*/true, color); for (const auto& buffer_chunk : result.chunk_map) { - const LogicalBuffer& buffer = *buffer_chunk.first; + // TODO(lauj) Remove this down_cast after downstream users of + // BufferAllocation::assigned_buffers() are updated to use BufferValue. + const LogicalBuffer& buffer = + *CHECK_NOTNULL(dynamic_cast<const LogicalBuffer*>(buffer_chunk.first)); const HeapSimulator::Chunk& chunk = buffer_chunk.second; assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 15fd905e8d593994c1cd5ec77cef6db7c2dbefdb..ad0b0bf7c25d7194a06801e4ef1c9ee961f6b915 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -415,10 +415,10 @@ class BufferAssignment { // Only BufferAssigner can build or modify BufferAssignments.
friend class BufferAssigner; - explicit BufferAssignment(const HloModule* module, - std::unique_ptr<BufferLiveness> liveness, - LogicalBuffer::SizeFunction buffer_size, - LogicalBuffer::AlignmentFunction color_alignment) + BufferAssignment(const HloModule* module, + std::unique_ptr<BufferLiveness> liveness, + LogicalBuffer::SizeFunction buffer_size, + LogicalBuffer::AlignmentFunction color_alignment) : module_(module), liveness_(std::move(liveness)), buffer_size_(std::move(buffer_size)), diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index a4fb0eefaca094898ed9acad8062484d1a36afe7..96d25675deab6af2d7c0d2aa10cc4087093c17f6 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" @@ -82,7 +81,7 @@ const std::vector GetInstructions(HloInstruction* root) { class BufferAssignmentTest : public HloTestBase { protected: - BufferAssignmentTest() : computation_tracker_() {} + BufferAssignmentTest() {} ~BufferAssignmentTest() override {} std::unique_ptr RunBufferAssignment(HloModule* module, @@ -252,9 +251,6 @@ class BufferAssignmentTest : public HloTestBase { return total_size; } - // Computation tracker for nested computations. - ComputationTracker computation_tracker_; - // Shapes for use in the examples. 
Shape s32_ = ShapeUtil::MakeShape(xla::S32, {}); Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {}); @@ -375,11 +371,11 @@ TEST_F(BufferAssignmentTest, Basic) { // param1[100] --------------/--------/ auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -422,11 +418,11 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { // share anything. auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -481,11 +477,11 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { // have the color 0, which allows the mul and add to share buffers. 
auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -551,11 +547,11 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) { // auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -605,7 +601,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) { // Creates the main kernel and verifies instruction counts. 
auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10_, "")); + HloInstruction::CreateParameter(0, f32a100x10_, "p")); auto map = builder.AddInstruction( HloInstruction::CreateMap(f32a100x10_, {param0}, map_computation)); module->AddEntryComputation(builder.Build()); @@ -658,7 +654,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10_, "")); + HloInstruction::CreateParameter(0, f32a100x10_, "p")); auto exp1 = builder.AddInstruction( HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, param0)); auto exp2 = builder.AddInstruction( @@ -822,7 +818,7 @@ TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg) auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32vec100_, "")); + HloInstruction::CreateParameter(0, f32vec100_, "p")); auto exp1 = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, param0)); auto tanh = builder.AddInstruction( @@ -1500,11 +1496,11 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { // param1[100] --------------/--------/ auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, 
param0)); auto add = builder.AddInstruction( @@ -1540,7 +1536,7 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { // be {%rev, %neg, %concat}. This occurs right at the concat itself. auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32vec100_, "")); + HloInstruction::CreateParameter(0, f32vec100_, "p")); auto log = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kLog, param)); auto rev = builder.AddInstruction( @@ -1797,7 +1793,7 @@ ENTRY %test_module { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(module_str)); + ParseHloString(module_str)); // Run CopyInsertion and check if the graph constructed above doesn't need // any copies inserted for BufferAssignment to run. diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 37982aaef9eddd64ef6b57ad5a9cf8dd6a565097..810d597e730c1823668c81598df6138655e58b55 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -44,7 +43,7 @@ StatusOr> BufferLiveness::Run( return std::move(liveness); } -tensorflow::Status BufferLiveness::Analyze() { +Status BufferLiveness::Analyze() { TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_)); for (auto* computation : module_->computations()) { if (computation->IsFusionComputation()) { @@ -71,7 +70,7 @@ tensorflow::Status BufferLiveness::Analyze() { } XLA_VLOG_LINES(3, ToString()); - return tensorflow::Status::OK(); + return Status::OK(); } string BufferLiveness::ToString() const { @@ -105,8 +104,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) { // Every user of 'a' must be a predecessor of 'b' or 'b' itself. for (auto user : alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), user, - points_to_analysis())) { + if (points_to_analysis().DoesNotUseOperandBuffer(alias.instruction(), + alias.index(), user)) { continue; } if (user != b.instruction() && @@ -132,9 +131,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, // the qualifications specified in CanShareOperandBufferWithUser. 
for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) { if (b.instruction()->IsUserOf(alias.instruction()) && - !CanShareOperandBufferWithUser(alias.instruction(), alias.index(), - b.instruction(), b.index(), - points_to_analysis())) { + !points_to_analysis().CanShareOperandBufferWithUser( + alias.instruction(), alias.index(), b.instruction(), b.index())) { return false; } } diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h index 11834a5127e383cc2ec2ab3fe1bb82ba86e4abed..cdd3cf4032ef6916086e1c2d148b575192503000 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.h +++ b/tensorflow/compiler/xla/service/buffer_liveness.h @@ -89,7 +89,7 @@ class BufferLiveness { // Perform buffer liveness analysis. This method must be called prior to // MayInterfere or MaybeLiveOut. - tensorflow::Status Analyze(); + Status Analyze(); // Returns true if the live range of the buffer of 'a' is strictly before the // live range of the buffer of 'b' (they do not overlap). diff --git a/tensorflow/compiler/xla/service/buffer_value_containers.h b/tensorflow/compiler/xla/service/buffer_value_containers.h new file mode 100644 index 0000000000000000000000000000000000000000..305914fca828f110bf54239bddb1590172562b16 --- /dev/null +++ b/tensorflow/compiler/xla/service/buffer_value_containers.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ + +#include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/core/lib/gtl/compactptrset.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { + +// Define various containers of BufferValues, and utilities to convert from +// containers of LogicalBuffers to containers of BufferValues. + +using BufferValueCompactPointerSet = + tensorflow::gtl::CompactPointerSet; +template +BufferValueCompactPointerSet ToBufferValueCompactPointerSet( + const LogicalBufferContainerT& logical_buffer_container) { + BufferValueCompactPointerSet output; + for (const LogicalBuffer* buffer : logical_buffer_container) { + output.insert(buffer); + } + return output; +} + +using BufferValueFlatSet = tensorflow::gtl::FlatSet; +template +BufferValueFlatSet ToBufferValueFlatSet( + const LogicalBufferContainerT& logical_buffer_container) { + BufferValueFlatSet output; + output.reserve(logical_buffer_container.size()); + for (const LogicalBuffer* buffer : logical_buffer_container) { + output.insert(buffer); + } + return output; +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index a8053d15e124319c5c898f0034b9aaa95a007a89..a23427f00ccd88bb0fe1d973a667f80ca54b14cd 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -57,6 +57,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) { case HloOpcode::kConditional: case HloOpcode::kWhile: return CallContext::kSequential; + case HloOpcode::kCrossReplicaSum: case HloOpcode::kMap: case HloOpcode::kReduce: case 
HloOpcode::kReduceWindow: diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h index c7763f2ca3e68490cd0cd9b4ba4d7bd180134080..fac0afd672ff3ed083aacf778dd9c4f90a2ee870 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.h +++ b/tensorflow/compiler/xla/service/channel_tracker.h @@ -19,9 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/compilation_cache.cc b/tensorflow/compiler/xla/service/compilation_cache.cc deleted file mode 100644 index b16907da9e9c909d2639f83895db27d724a84a7b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/compilation_cache.cc +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/compilation_cache.h" - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { - -std::shared_ptr CompilationCache::Insert( - std::unique_ptr executable, - const HloModuleConfig& module_config) { - tensorflow::mutex_lock lock(mutex_); - - CacheKey key = - BuildKey(executable->entry_computation_handle(), module_config); - VLOG(2) << "inserting cache key: " << key; - if (cache_.count(key) == 0) { - cache_.emplace(key, std::move(executable)); - } else { - // Executable already exists in the cache. This can happen if two Execute - // calls for a new computation are received simultaneously by the - // service. In this case, we discard the Executable given as a parameter and - // return what is in the cache. This is necessary because the service relies - // on the cache to keep ownership of the Executable. We only want to store - // one Executable for a given computation version and we can't discard the - // executable which is in the cache because it may be in use. 
- executable.reset(); - } - return cache_.at(key); -} - -std::shared_ptr CompilationCache::LookUp( - const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const { - tensorflow::mutex_lock lock(mutex_); - - CacheKey key = BuildKey(versioned_handle, module_config); - VLOG(2) << "looking up cache key: " << key; - if (cache_.count(key) == 0) { - VLOG(2) << "cache key not found: " << key; - return nullptr; - } else { - std::shared_ptr result = cache_.at(key); - VLOG(2) << "hit executable with module config: " - << result->module_config().compilation_cache_key(); - return result; - } -} - -CompilationCache::CacheKey CompilationCache::BuildKey( - const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const { - // The computation shape is represented entirely by its ProgramShape member, - // so just serialize the proto as part of the key. - return tensorflow::strings::StrCat(versioned_handle.handle.handle(), "::", - versioned_handle.version, "::", - module_config.compilation_cache_key()); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/compilation_cache.h b/tensorflow/compiler/xla/service/compilation_cache.h deleted file mode 100644 index 09989726ae6629aa65cb1dd84c16408a75019fa5..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/compilation_cache.h +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ - -#include -#include -#include - -#include "tensorflow/compiler/xla/service/executable.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" - -namespace xla { - -// A cache which stores Executables indexed by computation handle and version. -class CompilationCache { - public: - CompilationCache() {} - - // Insert the given Executable into the cache. Return a bare Executable - // pointer for the caller to use. Note: the returned pointer will *not* be the - // same as the given unique pointer if the computation already exists in the - // cache. See comments in the .cc implementation for details of this case. - // - // module_config is provided by the caller, instead of being taken from the - // executable, so that we can insert keys into the compilation cache that are - // devoid of layout (where XLA gets to choose what layout to compile). - // - // A shared_ptr is returned so the caller can keep the Executable from being - // destructed in the event that the Executable is evicted from the - // computation cache (and the cache's shared_ptr to the Executable is - // destructed). - std::shared_ptr Insert(std::unique_ptr executable, - const HloModuleConfig& module_config); - - // Lookup the Executable for the specified versioned computation in the cache. - // Return a shared_ptr to the Executable if it exists in the cache. Return - // nullptr otherwise. 
- std::shared_ptr LookUp( - const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const; - - protected: - mutable tensorflow::mutex mutex_; - - // Map from versioned handle with program layout to Executable built - // for that computation version and program layout. - using CacheKey = string; - - CacheKey BuildKey(const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const; - std::map> cache_ GUARDED_BY(mutex_); - - private: - TF_DISALLOW_COPY_AND_ASSIGN(CompilationCache); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index d39fd7307ae1b5bd0c431f98c413011ca081050b..d8fdccf9bbf1c1788bb4000aa702292362446503 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -104,56 +103,4 @@ CompileOnlyService::CompileAheadOfTime( return compiler_->CompileAheadOfTime(std::move(hlo_modules), options); } -StatusOr>> -CompileOnlyService::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { - std::vector> hlo_modules; - for (const AotComputationInstance& instance : computations) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(instance.computation)); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - const DebugOptions& debug_options = options.debug_options(); - - // Dump computation proto state if flag is set. 
- const string& directory_path = debug_options.xla_dump_computations_to(); - if (!directory_path.empty()) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr session_module, - computation_tracker_.SnapshotComputation(versioned_handle.handle)); - string filename = tensorflow::strings::StrCat( - "computation_", versioned_handle.handle.handle(), "__", - session_module->entry().name(), "__version_", - versioned_handle.version); - const string& per_host_path = tensorflow::io::JoinPath( - directory_path, tensorflow::port::Hostname()); - - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(per_host_path, filename, - *session_module)); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - ExecutionOptions execution_options; - *execution_options.mutable_debug_options() = debug_options; - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, instance.argument_layouts, - &execution_options, user_computation)); - - TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, - computation_tracker_.BuildHloModule( - versioned_handle, *module_config, - /*include_unreachable_instructions=*/true)); - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module)); - hlo_modules.push_back(std::move(hlo_module)); - } - - return compiler_->CompileAheadOfTime(std::move(hlo_modules), options); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h index c10609e67fcdec459baf25a95173bbf700994be9..e6a66c202d6e0df3cb6d165e51beb25abd8ec45c 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.h +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -38,24 +38,7 @@ class CompileOnlyService : public Service { static StatusOr> NewService( const ServiceOptions& options); - // A description of a computation to compile using CompileAheadOfTime. 
- struct AotComputationInstance { - ComputationHandle computation; - std::vector argument_layouts; - const Shape* result_layout = nullptr; - }; - - // Compiles a list of computations for ahead-of-time execution. This is - // intended for use in static compilation. See - // |CompileOnlyClient::CompileAheadOfTime| for additional details. - StatusOr>> - CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& Options); - // A description of a xla computation to compile using CompileAheadOfTime. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. struct AotXlaComputationInstance { HloModuleProto computation; std::vector argument_layouts; @@ -65,58 +48,36 @@ class CompileOnlyService : public Service { // Compiles a list of xla computations for ahead-of-time execution. This is // intended for use in static compilation. See // |CompileOnlyClient::CompileAheadOfTime| for additional details. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr>> CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, const AotCompilationOptions& options); - // Override Service methods that require or imply the existence of an - // execute backend. Note that this does not include TransferToClient, as - // computing constants produces global data that we may wish to transfer. 
- tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } - tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } - tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) override { + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override { return Unimplemented("CompileOnlyService does not support devices."); } - tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } - tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) override { + Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); } - tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, - TransferToServerResponse* result) override { + Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) override { + Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) override { + Status TransferFromOutfeed(const 
TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) override { + Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override { return Unimplemented("CompileOnlyService does not support devices."); } diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 8b01a6c4b5004d03e6e7d23b99b923fdcdeaff99..6f06bba6798bdff51f10d8fe9dc524d8064ba849 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -28,6 +28,13 @@ namespace xla { /* static */ tensorflow::mutex Compiler::platform_compiler_mutex_( tensorflow::LINKER_INITIALIZED); +std::vector> +Compiler::ComputeBackendConfigs(const HloInstruction& hlo, + se::StreamExecutor* executor) const { + CHECK(executor != nullptr); + return {}; +} + /* static */ std::map* Compiler::GetPlatformCompilerFactories() { static auto* r = new std::map; diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index a4b59d1ba9b24e3f886a7feb51181ae8f990951f..6c52ffd800d19de83877341d41ef81eee2de7251 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -24,9 +24,11 @@ limitations under the License. #include #include #include +#include #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" @@ -34,6 +36,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -153,6 +156,16 @@ class Compiler { std::vector> stream_exec, DeviceMemoryAllocator* device_allocator) = 0; + // Returns the backend configurations that the backend will consider for the + // given HLO. Returns no configurations if the backend does not support + // configurations for the given HLO. + // + // The stream executor is passed in to provide information about the hardware + // that the backend configurations would be targeting. + virtual std::vector> + ComputeBackendConfigs(const HloInstruction& hlo, + se::StreamExecutor* executor) const; + // Compiles the HLO module for ahead-of-time execution. This is intended for // use in static compilation. virtual StatusOr>> diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc index d2d4f14fcec35f5b51a2670a646154ce8bb9bfc1..cb61f3da39fb8eef69fd81066d87a1da91a62935 100644 --- a/tensorflow/compiler/xla/service/computation_layout.cc +++ b/tensorflow/compiler/xla/service/computation_layout.cc @@ -23,12 +23,15 @@ limitations under the License. 
namespace xla { -ComputationLayout::ComputationLayout(const ProgramShape& program_shape) +ComputationLayout::ComputationLayout(const ProgramShape& program_shape, + bool ignore_layouts) : result_layout_(program_shape.result()) { for (auto& shape : program_shape.parameters()) { parameter_layouts_.emplace_back(shape); } - SetToDefaultLayout(); + if (ignore_layouts) { + SetToDefaultLayout(); + } } void ComputationLayout::SetToDefaultLayout() { diff --git a/tensorflow/compiler/xla/service/computation_layout.h b/tensorflow/compiler/xla/service/computation_layout.h index 80e102411c7885669947d89f378b1ec61e3e4e96..6975f387b4864bf28ea0ad23d7d4602b5b346e08 100644 --- a/tensorflow/compiler/xla/service/computation_layout.h +++ b/tensorflow/compiler/xla/service/computation_layout.h @@ -32,10 +32,20 @@ namespace xla { // mutable layouts. class ComputationLayout { public: + // Creates a new ComputationLayout with the given result layout. + explicit ComputationLayout(ShapeLayout result_layout) + : result_layout_(std::move(result_layout)) {} + // Constructs a ComputationLayout from a ProgramShape. The layouts of the // parameters and results are set to the default layout. Layouts in the - // ProgramShape are ignored. - explicit ComputationLayout(const ProgramShape& program_shape); + // ProgramShape are ignored if ignore_layouts is true. + explicit ComputationLayout(const ProgramShape& program_shape, + bool ignore_layouts = true); + + // Adds a new parameter layout to the computation layout. + void add_parameter_layout(ShapeLayout shape_layout) { + parameter_layouts_.push_back(std::move(shape_layout)); + } // Returns the layout of a particular parameter. 
const ShapeLayout& parameter_layout(int64 param_no) const { diff --git a/tensorflow/compiler/xla/service/computation_tracker.cc b/tensorflow/compiler/xla/service/computation_tracker.cc deleted file mode 100644 index 70e25eebdb068db893e24aec0f72d09090ac7027..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/computation_tracker.cc +++ /dev/null @@ -1,256 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/computation_tracker.h" - -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/logging.h" - -using ::tensorflow::strings::Appendf; - -namespace xla { - -ComputationTracker::ComputationTracker() : next_computation_(1) {} - -ComputationHandle ComputationTracker::NewComputation( - const string& computation_name) { - tensorflow::mutex_lock lock(computation_mutex_); - ComputationHandle computation_handle; - int64 handle_value = next_computation_++; - computation_handle.set_handle(handle_value); - 
opaque_to_computation_[handle_value] = - MakeUnique(computation_name, computation_handle); - return computation_handle; -} - -StatusOr ComputationTracker::LoadSessionModule( - const SessionModule& session_module) { - tensorflow::mutex_lock lock(computation_mutex_); - - // For each embedded computation, create a new computation based on its - // serialized data, and place the mapping from the old computation handle to - // the new computation handle. - - // Build a mapping from old embedded computation handles to new computation - // handles. We build the ID mapping first since the embedded computations are - // in no particular order and may refer to each other. - std::map old_to_new; - for (const SessionComputation& computation : - session_module.embedded_computations()) { - const int64 old_handle = computation.computation_handle().handle(); - if (!old_to_new.emplace(old_handle, AllocateHandle()).second) { - return InvalidArgument("Duplicate embedded computation handle %lld", - old_handle); - } - } - - // Create a new computation from each serialized embedded computation. - for (const SessionComputation& computation : - session_module.embedded_computations()) { - const int64 old_handle = computation.computation_handle().handle(); - const ComputationHandle& new_handle = old_to_new[old_handle]; - TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()], - UserComputation::MakeWithRemapping( - computation, new_handle, old_to_new)); - } - - // Finally, place the entry computation in the tracker with all of the - // remappings populated from the above. 
- const int64 old_handle = session_module.entry().computation_handle().handle(); - TF_ASSIGN_OR_RETURN( - old_to_new[old_handle], - LoadSessionComputation(session_module.entry(), &old_to_new)); - return old_to_new[old_handle]; -} - -StatusOr> -ComputationTracker::SnapshotComputation(const ComputationHandle& computation) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, Resolve(computation)); - const VersionedComputationHandle entry_versioned_handle = - user_computation->GetVersionedHandle(); - std::set visited; - std::list post_order; - { - tensorflow::mutex_lock lock(computation_mutex_); - ComputeComputationPostOrder(entry_versioned_handle, &visited, &post_order); - } - auto session_module = MakeUnique(); - *session_module->mutable_entry() = - Resolve(entry_versioned_handle.handle) - .ValueOrDie() - ->CloneSessionComputation(entry_versioned_handle.version); - for (auto it = ++post_order.rbegin(); it != post_order.rend(); ++it) { - *session_module->add_embedded_computations() = - Resolve(it->handle).ValueOrDie()->CloneSessionComputation(it->version); - } - return std::move(session_module); -} - -StatusOr ComputationTracker::Resolve( - const ComputationHandle& computation) const { - tensorflow::mutex_lock lock(computation_mutex_); - return ResolveInternal(computation); -} - -ComputationHandle ComputationTracker::AllocateHandle() { - int64 handle_value = next_computation_++; - ComputationHandle result; - result.set_handle(handle_value); - return result; -} - -StatusOr ComputationTracker::LoadSessionComputation( - const SessionComputation& session_computation, - std::map* old_to_new) { - TF_RET_CHECK(old_to_new != nullptr); - const ComputationHandle new_handle = AllocateHandle(); - (*old_to_new)[session_computation.computation_handle().handle()] = new_handle; - TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()], - UserComputation::MakeWithRemapping( - session_computation, new_handle, *old_to_new)); - return new_handle; -} - -StatusOr 
ComputationTracker::ResolveInternal( - const ComputationHandle& computation) const { - auto it = opaque_to_computation_.find(computation.handle()); - if (it == opaque_to_computation_.end()) { - return NotFound("computation handle not found: %lld", computation.handle()); - } - UserComputation* user_computation = it->second.get(); - return user_computation; -} - -void ComputationTracker::ComputeComputationPostOrder( - const VersionedComputationHandle& versioned_handle, - std::set* visited, - std::list* post_order) const { - if (visited->count(versioned_handle) > 0) { - CHECK_EQ(1, visited->count(versioned_handle)); - return; - } - - UserComputation* computation = - ResolveInternal(versioned_handle.handle).ValueOrDie(); - std::vector embedded_handles = - computation->GetEmbeddedComputations(versioned_handle.version); - - for (const auto& embedded_handle : embedded_handles) { - ComputeComputationPostOrder(embedded_handle, visited, post_order); - } - - visited->insert(versioned_handle); - post_order->push_back(versioned_handle); -} - -StatusOr> ComputationTracker::BuildHloModule( - const VersionedComputationHandle& entry_handle, - const HloModuleConfig& config, - bool include_unreachable_instructions) const { - tensorflow::mutex_lock lock(computation_mutex_); - - VLOG(1) << "BuildHloModule(" << entry_handle - << ", include_unreachable_instructions=" - << include_unreachable_instructions << ")"; - XLA_VLOG_LINES(1, ToStringInternal()); - - TF_ASSIGN_OR_RETURN(UserComputation * entry_computation, - ResolveInternal(entry_handle.handle)); - - // Build a topological sort of the entry and any embedded computations as a - // list. The root of the computation will be the last element in the list. - std::set visited; - std::list post_order; - ComputeComputationPostOrder(entry_handle, &visited, &post_order); - - // Map from ComputationHandle value and computation version to HloComputation. 
- std::map hlo_computations; - - // The resolver lambda resolves VersionedHandles to embedded - // HloComputation*. This is required by UserComputation::BuildHloComputation - // when lowering calling operations (map, reduce etc). - auto resolver = [&hlo_computations]( - const VersionedComputationHandle& versioned_handle) -> HloComputation* { - CHECK_GT(hlo_computations.count(versioned_handle), 0); - return hlo_computations.at(versioned_handle); - }; - - // Print the post-order list for this entry computation. - if (VLOG_IS_ON(2)) { - VLOG(2) << "Visiting UserComputations in post order:"; - for (const VersionedComputationHandle& versioned_handle : post_order) { - VLOG(2) << " " << versioned_handle; - } - } - - string module_name = - tensorflow::strings::StrCat(entry_computation->name(), "_module"); - auto module = MakeUnique(module_name, entry_handle, config); - for (auto versioned_handle : post_order) { - UserComputation* computation = - ResolveInternal(versioned_handle.handle).ValueOrDie(); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_computation, - computation->BuildHloComputation(versioned_handle.version, resolver, - config.debug_options(), - include_unreachable_instructions)); - - // Add the newly created computation to VersionedHandle-to-HloComputation - // map. 
- DCHECK_EQ(0, hlo_computations.count(versioned_handle)); - hlo_computations[versioned_handle] = hlo_computation.get(); - - if (computation == entry_computation) { - module->AddEntryComputation(std::move(hlo_computation)); - } else { - module->AddEmbeddedComputation(std::move(hlo_computation)); - } - } - - return std::move(module); -} - -string ComputationTracker::ToString() const { - tensorflow::mutex_lock lock(computation_mutex_); - return ToStringInternal(); -} - -string ComputationTracker::ToStringInternal() const { - string out; - Appendf(&out, "ComputationTracker(%p):\n", this); - for (const auto& handle_computation : opaque_to_computation_) { - int64 handle = handle_computation.first; - const std::unique_ptr& computation = - handle_computation.second; - Appendf(&out, " %4lld : %s \"%s\"\n", handle, - computation->GetVersionedHandle().ToString().c_str(), - computation->name().c_str()); - } - return out; -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/computation_tracker.h b/tensorflow/compiler/xla/service/computation_tracker.h deleted file mode 100644 index d42d66adefe7faa2751da4cd80b392a38917ce70..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/computation_tracker.h +++ /dev/null @@ -1,147 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// Tracks computations for the XLA service; computations can be registered -// with a UserComputation instance and can be resolved from a handle for later -// use. -// -// This class is also capable of serializing/deserializing computations that it -// tracks (and to serialize properly you need to serialize all referred-to -// computations as well). -class ComputationTracker { - public: - ComputationTracker(); - - // Creates a new UserComputation object and returns the corresponding - // ComputationHandle for it. - // - // Precondition: user_computation is not already present in the map. - ComputationHandle NewComputation(const string& computation_name); - - // Restores session data for a computation that has been serialized, and - // allocates a new computation handle for it. 
- StatusOr LoadSessionModule( - const SessionModule& session_module); - - // Snapshots a computation (referenced by the provided handle) at its latest - // version, returning a module where it is the entry, and any referred-to - // computations are entrained as "embedded" (non-entry) computations. - StatusOr> SnapshotComputation( - const ComputationHandle& computation); - - // Resolves a ComputationHandle to a UserComputation that is present in the - // map. - StatusOr Resolve( - const ComputationHandle& computation) const; - - // Builds an HLO module using the specified computation as the entry. The - // module will include the entry computation as well as all computations which - // are called directly or indirectly from the entry computation via operations - // like "map". config is the HLO module configuration to use for the - // constructed module. - // If include_unreachable_instructions is true, then instructions - // which are not reachable from the root are lowered into HloInstructions - // including unreachable parameters. This ensures the entry HloComputation has - // the same program shape (ProgramShape) as the entry UserComputation. - StatusOr> BuildHloModule( - const VersionedComputationHandle& entry_handle, - const HloModuleConfig& config, - bool include_unreachable_instructions = true) const; - - string ToString() const; - - private: - // Bumps the next_computation_ number and returns the allocated number wrapped - // in a ComputationHandle. - ComputationHandle AllocateHandle() - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Loads a session computation into a UserComputation, registers it, and - // returns the computation handle of the registered computation. If old_to_new - // is provided, it is used for remapping references to computations present in - // session_computation. - // - // old_to_new will be updated with the mapping from session_computation's old - // handle to the returned handle value, and may not be null. 
- StatusOr LoadSessionComputation( - const SessionComputation& session_computation, - std::map* old_to_new) - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Internal implementation of Resolve method which requires, but does not - // acquire the mutex. - StatusOr ResolveInternal( - const ComputationHandle& computation) const - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Builds a post order sort of a computation ("entry") and all of its embedded - // computations including all transitively embedded computations. An embedded - // computation (the callee) will always appear in the sort before the - // computation which calls the embedded computation (the caller). Necessarily, - // the entry computation is the last element in the sort. visited and - // post_order should be empty when calling. post_order contains the post order - // sort when the function return. - void ComputeComputationPostOrder( - const VersionedComputationHandle& versioned_handle, - std::set* visited, - std::list* post_order) const - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - string ToStringInternal() const EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Guards the computation mapping. Marked mutable so that the Resolve method - // can remain const; Resolve does't really modify the tracker in any way, but - // it has to lock the mutex for safety. - mutable tensorflow::mutex computation_mutex_; - - // The next sequence number to assign to a computation, guarded by the same - // mutex as the mapping as they'll be mutated at the same time. - int64 next_computation_ GUARDED_BY(computation_mutex_); - - // Mapping from ComputationHandle value to the corresponding registered - // UserComputation object. 
- std::map> opaque_to_computation_ - GUARDED_BY(computation_mutex_); - - TF_DISALLOW_COPY_AND_ASSIGN(ComputationTracker); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index cbe2ba2e50ab213133196987cf486152edc9d785..33d8338809d4e8c7c4774f062c3dda5494543ca6 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 153f062d015e49db11c4c9ae0a2a61e76c020f02..684fff8a6fd4fd045a0de9df9660b887f32bdf40 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -1636,8 +1636,7 @@ void BM_SequentialWhiles(int num_iters, int num_whiles) { for (int i = 0; i < num_iters; ++i) { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), - config); + HloModule module("BM_SequentialWhiles", config); auto builder = HloComputation::Builder("BM_SequentialWhiles"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1677,8 +1676,7 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) { for (int i = 0; i < num_iters; ++i) { HloModuleConfig config; 
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), - config); + HloModule module("BM_SequentialWhiles", config); auto builder = HloComputation::Builder("BM_ParallelWhiles"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1750,8 +1748,7 @@ void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) { std::vector tuple_params(num_tuple_inputs); for (int i = 0; i < num_iters; ++i) { auto builder = HloComputation::Builder("BM_ParallelWhiles"); - HloModule module("BM_ManyElementTuple", VersionedComputationHandle(), - config); + HloModule module("BM_ManyElementTuple", config); for (int j = 0; j < num_tuple_inputs; ++j) { tuple_params[j] = builder.AddInstruction( HloInstruction::CreateParameter(j, element_shape, "")); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 7e6d58c7fa5ccaf3e0a6f21d43a54906a3fbe408..b703be0f39e2032bc58479f0b957f9d8b01a77c3 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -103,6 +103,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:algebraic_simplifier", + "//tensorflow/compiler/xla/service:batch_dot_simplification", "//tensorflow/compiler/xla/service:batchnorm_expander", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", @@ -125,6 +126,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_scheduling", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:indexed_array_analysis", "//tensorflow/compiler/xla/service:inliner", "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", @@ -149,7 +151,14 @@ cc_library( 
"@llvm//:target", # fixdeps: keep "@llvm//:x86_code_gen", # fixdeps: keep "@llvm//:x86_disassembler", # fixdeps: keep - ], + ] + select({ + "@org_tensorflow//tensorflow:linux_ppc64le": [ + "@llvm//:powerpc_disassembler", + "@llvm//:powerpc_code_gen", + ], + "//conditions:default": [ + ], + }), alwayslink = True, # Contains compiler registration ) @@ -176,6 +185,7 @@ cc_library( ":runtime_matmul", ":runtime_matmul_mkl", ":runtime_single_threaded_conv2d", + ":runtime_single_threaded_fft", ":runtime_single_threaded_matmul", "@llvm//:execution_engine", "@llvm//:core", @@ -295,6 +305,15 @@ cc_library( ], ) +cc_library( + name = "target_machine_features_fake", + testonly = 1, + hdrs = ["target_machine_features_fake.h"], + deps = [ + ":target_machine_features", + ], +) + cc_library( name = "ir_function", srcs = ["ir_function.cc"], @@ -336,6 +355,7 @@ cc_library( deps = [ ":cpu_options", ":cpu_runtime", + ":ir_emission_utils", ":target_machine_features", ":vector_support_library", "//tensorflow/compiler/xla:shape_util", @@ -408,7 +428,6 @@ cc_library( "//tensorflow/core:lib", "@llvm//:analysis", "@llvm//:core", - "@llvm//:execution_engine", "@llvm//:ipo", "@llvm//:mc", "@llvm//:object", @@ -505,7 +524,6 @@ cc_library( deps = [ "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:framework", "//tensorflow/core:framework_lite", "//third_party/eigen3", ], @@ -567,6 +585,22 @@ cc_library( ], ) +cc_library( + name = "runtime_single_threaded_fft", + srcs = [ + "runtime_fft_impl.h", + "runtime_single_threaded_fft.cc", + ], + hdrs = ["runtime_single_threaded_fft.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + cc_library( name = "runtime_single_threaded_matmul", srcs = ["runtime_single_threaded_matmul.cc"], @@ -622,10 +656,10 @@ tf_cc_test( deps = [ 
":cpu_instruction_fusion", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -660,6 +694,7 @@ cc_library( hdrs = ["ir_emission_utils.h"], deps = [ ":cpu_runtime", + ":target_machine_features", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/service:hlo", @@ -672,14 +707,15 @@ tf_cc_test( srcs = ["ir_emission_utils_test.cc"], deps = [ ":ir_emission_utils", + ":target_machine_features_fake", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -690,6 +726,7 @@ cc_library( deps = [ ":dot_op_emitter", ":ir_emission_utils", + ":target_machine_features", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:layout_assignment", @@ -703,6 +740,7 @@ tf_cc_test( srcs = ["cpu_layout_assignment_test.cc"], deps = [ ":cpu_layout_assignment", + ":target_machine_features_fake", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -727,6 +765,7 @@ cc_library( deps = [ ":cpu_runtime", ":ir_emission_utils", + ":target_machine_features", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -741,6 +780,7 @@ tf_cc_test( srcs = ["conv_canonicalization_test.cc"], 
deps = [ ":conv_canonicalization", + ":target_machine_features_fake", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", @@ -779,6 +819,7 @@ cc_library( ":dot_op_emitter", ":ir_emission_utils", ":shape_partition", + ":target_machine_features", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", @@ -791,6 +832,7 @@ tf_cc_test( deps = [ ":cpu_executable", ":parallel_task_assignment", + ":target_machine_features_fake", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -863,6 +905,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", "@llvm//:core", "@llvm//:support", ], @@ -913,3 +956,17 @@ tf_cc_test( "//tensorflow/core:test", ], ) + +tf_cc_test( + name = "cpu_eigen_tensor_alignment_test", + size = "small", + srcs = ["cpu_eigen_tensor_alignment_test.cc"], + deps = [ + ":dot_op_emitter", + ":ir_emission_utils", + ":target_machine_features_fake", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 2136aeb3877685373efaf5bf702a42b39a63f082..0985b9297fe487f3523826cb0978c17775549735 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -33,7 +33,8 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { for (HloInstruction* hlo : module->entry_computation()->MakeInstructionPostOrder()) { if (hlo->opcode() == HloOpcode::kConvolution && - !PotentiallyImplementedAsEigenConvolution(*hlo)) { + 
!PotentiallyImplementedAsEigenConvolution(*hlo, + target_machine_features_)) { const ConvolutionDimensionNumbers& dnums = hlo->convolution_dimension_numbers(); auto input_batch_dim = dnums.input_batch_dimension(); diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h index 9b2c3d82eb673ce542cc03ec706015967dc975b6..e6fd1499edd0095395194200a5b444ad61e7e39d 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_ +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -32,12 +33,19 @@ namespace cpu { // convolutions can run faster. class ConvCanonicalization : public HloPassInterface { public: + explicit ConvCanonicalization( + const TargetMachineFeatures* target_machine_features) + : target_machine_features_(*target_machine_features) {} + ~ConvCanonicalization() override {} tensorflow::StringPiece name() const override { return "convolution-canonicalization"; } StatusOr Run(HloModule* module) override; + + private: + const TargetMachineFeatures& target_machine_features_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 968f53d5c706651d2a470a853e0e9b601c0ed2df..375b017b09263c20c1b1ef8329f7e2f6a573dda4 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -89,7 +90,11 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - ConvCanonicalization conv_canonicalization; + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + ConvCanonicalization conv_canonicalization(&target_machine_features); EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie()); const HloInstruction* output_reshape = entry_computation->root_instruction(); @@ -146,7 +151,11 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - ConvCanonicalization conv_canonicalization; + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + ConvCanonicalization conv_canonicalization(&target_machine_features); EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie()); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 3d2e24ca14eacd1a26e118a636dcaca5f2768f15..4c0e189e78674b709ccc5e05fa629ceb90dbda8c 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -43,6 +43,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/batch_dot_simplification.h" #include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" @@ -81,6 +82,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/indexed_array_analysis.h" #include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" @@ -231,7 +233,10 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { }; } // namespace -Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { +Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, + llvm::TargetMachine* target_machine) { + LLVMTargetMachineFeatures target_machine_features(target_machine); + // Optimization pipeline. HloPassPipeline pipeline("CPU"); pipeline.AddInvariantChecker(); @@ -248,8 +253,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner // pass. 
pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); + pipeline.AddPass(&target_machine_features); { auto& pass = pipeline.AddPass>("simplification"); @@ -258,8 +264,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, - /*rewrite_grad_op=*/true, - /*use_fusion=*/false); + /*rewrite_grad_op=*/true); pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, @@ -278,10 +283,12 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pass.AddPass(); pass.AddPass(); } + pipeline.AddPass(); pipeline.AddPass( - [](const HloInstruction& dot, - const TransposeFolding::OperandIndices& candidate_operands) { - return PotentiallyImplementedAsEigenDot(dot) + [&target_machine_features]( + const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return PotentiallyImplementedAsEigenDot(dot, target_machine_features) ? candidate_operands : TransposeFolding::OperandIndices{}; }, @@ -296,7 +303,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { ReducePrecisionInsertion::PassTiming::AFTER_FUSION); pipeline.AddPass( - module->device_entry_computation_layout()); + module->mutable_device_entry_computation_layout(), + &target_machine_features); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. pipeline.AddPass>( @@ -316,8 +324,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // and thread synchronization dependencies which would likely increase // binary size (and most AOT applications are single-threaded). // TODO(b/29630486) Support multi-threaded AOT. 
- pipeline.AddPass(max_parallelism, - ShapeSizeBytesFunction()); + pipeline.AddPass( + max_parallelism, ShapeSizeBytesFunction(), &target_machine_features); } // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -470,7 +478,13 @@ StatusOr> CpuCompiler::RunHloPasses( VLOG(2) << "Before optimization:"; XLA_VLOG_LINES(2, module->ToString()); - TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false)); + std::unique_ptr jit_target_machine = + SimpleOrcJIT::InferTargetMachineForJIT( + CompilerTargetOptions(module->config()), + CodeGenOptLevel(module->config())); + + TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false, + jit_target_machine.get())); VLOG(2) << "After optimization:"; XLA_VLOG_LINES(2, module->ToString()); @@ -561,10 +575,11 @@ StatusOr> CpuCompiler::RunBackend( // GetEmbeddedComputations guarantees that a called computation occurs // before a caller computation. 
+ LLVMTargetMachineFeatures target_machine_features(jit->target_machine()); IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), - jit->target_machine(), jit->external_constant_pool()); + &target_machine_features, jit->external_constant_pool()); for (auto embedded_computation : entry_computation->MakeEmbeddedComputationsList()) { @@ -706,7 +721,8 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, VLOG(2) << "Before optimization:"; XLA_VLOG_LINES(2, module->ToString()); - TF_RETURN_IF_ERROR(RunHloPasses(module, /*is_aot_compile=*/true)); + TF_RETURN_IF_ERROR( + RunHloPasses(module, /*is_aot_compile=*/true, target_machine.get())); VLOG(2) << "After optimization:"; XLA_VLOG_LINES(2, module->ToString()); @@ -746,10 +762,11 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, &hlo_profile_index_map, &hlo_profile_printer_data)); } + LLVMTargetMachineFeatures target_machine_features(target_machine.get()); IrEmitter ir_emitter(*module, *assignment, &llvm_module, std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), - target_machine.get(), + &target_machine_features, /*external_constant_pool=*/nullptr); HloComputation* computation = module->entry_computation(); for (auto embedded_computation : diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index 65b05f04fa8d9c72e7bfb6978f6a6384dfbcf976..e56f9f01134f84b4698c078b750b0c1fdca7748e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -18,6 +18,7 @@ limitations under the License. 
#include +#include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" @@ -148,7 +149,8 @@ class CpuCompiler : public LLVMCompiler { // Runs the HLO passes which are necessary for both optimizations and // correctness. - Status RunHloPasses(HloModule* module, bool is_aot_compile); + Status RunHloPasses(HloModule* module, bool is_aot_compile, + llvm::TargetMachine* target_machine); TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8727c72b6e42517b1859e98ecadb41bbceed761c --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc @@ -0,0 +1,94 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace cpu { +namespace { + +// Test that we don't call into Eigen with tensors too small to be aligned +// reliably. + +class CpuEigenTensorAlignmentTest : public ::testing::Test {}; + +TEST_F(CpuEigenTensorAlignmentTest, EigenDotAlignment) { + string hlo_string = R"( +HloModule DotOperation + +ENTRY DotOperation { + arg0 = f32[5,256] parameter(0) + arg1 = f32[256,1024] parameter(1) + ROOT dot = f32[5,1024] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + HloInstruction* dot = module->entry_computation()->root_instruction(); + + TargetMachineFeaturesWithFakeAlignmentLogic target_machine_with_no_alignment( + [](int64 size) { return 1; }); + + EXPECT_FALSE( + PotentiallyImplementedAsEigenDot(*dot, target_machine_with_no_alignment)); + + TargetMachineFeaturesWithFakeAlignmentLogic + target_machine_with_full_alignment([](int64 size) { + return TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + + EXPECT_TRUE(PotentiallyImplementedAsEigenDot( + *dot, target_machine_with_full_alignment)); +} + +TEST_F(CpuEigenTensorAlignmentTest, EigenConvAlignment) { + string hlo_string = R"( +HloModule ConvOperation + +ENTRY ConvOperation { + arg0 = f32[1,2,1] parameter(0) + arg1 = f32[1,1,1] parameter(1) + ROOT conv = f32[1,2,1] convolution(arg0, arg1), window={size=1}, dim_labels=b0f_0io->b0f +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + HloInstruction* conv = 
module->entry_computation()->root_instruction(); + + TargetMachineFeaturesWithFakeAlignmentLogic target_machine_with_no_alignment( + [](int64 size) { return 1; }); + + EXPECT_FALSE(PotentiallyImplementedAsEigenConvolution( + *conv, target_machine_with_no_alignment)); + + TargetMachineFeaturesWithFakeAlignmentLogic + target_machine_with_full_alignment([](int64 size) { + return TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + + EXPECT_TRUE(PotentiallyImplementedAsEigenConvolution( + *conv, target_machine_with_full_alignment)); +} +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 32613b869078305edda97c11ac250f67de32b805..cf43b74c699ca8cbbef11a0abbaf4d69476f5d77 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -73,7 +73,7 @@ CpuExecutable::CpuExecutable( Status CpuExecutable::AllocateBuffers( DeviceMemoryAllocator* memory_allocator, int device_ordinal, - std::vector* buffers) { + std::vector* buffers) { CHECK_EQ(buffers->size(), assignment_->Allocations().size()); VLOG(3) << "Allocating " << assignment_->Allocations().size() << " allocations for module " << module().name(); @@ -201,60 +201,18 @@ Status CpuExecutable::ExecuteComputeFunction( return Status::OK(); } -static void LogLiveAddresses( - tensorflow::gtl::ArraySlice buffers, - const std::vector& buffers_in_result) { - if (!VLOG_IS_ON(3)) { - return; - } - - CHECK_EQ(buffers.size(), buffers_in_result.size()); - std::vector live_out_buffers; - for (int i = 0; i < buffers.size(); ++i) { - if (buffers_in_result[i]) { - live_out_buffers.push_back(buffers[i].opaque()); - } - } - VLOG(3) << "Live addresses in output marking found " - << live_out_buffers.size() << " addresses:\n" - << tensorflow::str_util::Join( - live_out_buffers, ", ", [](string* out, const void* address) { - 
tensorflow::strings::StrAppend( - out, tensorflow::strings::Printf("%p", address)); - }); -} - -static Status DeallocateTempBuffers( - DeviceMemoryAllocator* allocator, se::Stream* stream, - tensorflow::gtl::ArraySlice buffers, - const std::vector& buffers_in_result) { - // Keep those buffers in the output of the marked live because they are needed - // by the service. They will be deallocated by the service. - for (size_t i = 0; i < buffers.size(); ++i) { - se::DeviceMemoryBase alloc = buffers[i]; - if (!buffers_in_result[i] && !alloc.is_null()) { - VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" - << alloc.opaque() << "]"; - TF_RETURN_IF_ERROR( - allocator->Deallocate(stream->parent()->device_ordinal(), &alloc)); - } - } - - return Status::OK(); -} - StatusOr CpuExecutable::CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice allocated_buffers, - std::vector* buffers_in_result) { + tensorflow::gtl::MutableArraySlice buffers) { se::Stream* stream = run_options->stream(); ScopedShapedBuffer result_buffer( /*on_host_shape=*/host_result_shape(), /*on_device_shape=*/host_result_shape(), run_options->allocator(), stream->parent()->device_ordinal()); - // Copy DeviceMemoryBase values which contain the array(s) of the result into - // the respective location in ShapedBuffer which is returned to the caller. + // Move OwningDeviceMemory values which contain the array(s) of the result + // into the respective location in ScopedShapedBuffer which is returned to the + // caller. 
TF_RETURN_IF_ERROR(result_buffer.buffers().ForEachMutableElementWithStatus( [&](const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { const auto& sources = this->GetRootPointsToSet().element(index); @@ -273,10 +231,9 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( CHECK(!slice.allocation()->is_entry_computation_parameter()); const BufferAllocation::Index buffer_index = slice.index(); - const se::DeviceMemoryBase& buffer = allocated_buffers[buffer_index]; + OwningDeviceMemory& buffer = buffers[buffer_index]; CHECK(!buffer.is_null() || buffer.size() == 0); - *device_memory = buffer; - (*buffers_in_result)[buffer_index] = true; + *device_memory = buffer.Forget(); return Status::OK(); })); return std::move(result_buffer); @@ -292,23 +249,21 @@ StatusOr CpuExecutable::ExecuteOnStream( se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); + std::vector buffers(assignment_->Allocations().size()); TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); - TF_RETURN_IF_ERROR(ExecuteComputeFunction( - &run_options->run_options(), arguments, buffers, hlo_execution_profile)); - std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer result_buffer, - CreateResultShapedBuffer(run_options, buffers, &buffers_in_result)); - - // Free all buffers not in the result. 
- TF_RETURN_IF_ERROR(DeallocateTempBuffers(memory_allocator, stream, buffers, - buffers_in_result)); + std::vector unowning_buffers; + unowning_buffers.reserve(buffers.size()); + for (auto& buffer : buffers) { + unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); + } + TF_RETURN_IF_ERROR(ExecuteComputeFunction(&run_options->run_options(), + arguments, unowning_buffers, + hlo_execution_profile)); - return std::move(result_buffer); + return CreateResultShapedBuffer(run_options, &buffers); } StatusOr CpuExecutable::ExecuteAsyncOnStream( @@ -324,30 +279,53 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( run_options->stream()->implementation()); se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); - + std::vector buffers(assignment_->Allocations().size()); TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); - std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer result_buffer, - CreateResultShapedBuffer(run_options, buffers, &buffers_in_result)); - - LogLiveAddresses(buffers, buffers_in_result); - - host_stream->EnqueueTask([this, run_options, arguments, buffers, - buffers_in_result, memory_allocator, stream]() { - // Failing a CHECK here is not great, but I don't see an obvious way to - // return a failed Status asynchronously. 
- TF_CHECK_OK(ExecuteComputeFunction(&run_options->run_options(), arguments, - buffers, - /*hlo_execution_profile=*/nullptr)); - TF_CHECK_OK(DeallocateTempBuffers(memory_allocator, stream, buffers, - buffers_in_result)); - }); + std::vector unowning_buffers; + unowning_buffers.reserve(buffers.size()); + for (auto& buffer : buffers) { + unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); + } + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, + CreateResultShapedBuffer(run_options, &buffers)); - return std::move(result_buffer); + // At this point, `unowning_buffers` contains unowning pointers to all of our + // buffers, and `buffers` contains owning pointers to the non-live-out + // buffers. Enqueue a task which keeps alive the non-live-out buffers. + // + // Logically we want this lambda to capture `buffers` by move, but ultimately our + // functor needs to be wrapped in an std::function, and that requires its + // functor to be copyable. Thus we perpetrate the hack of capturing buffers + // "by shared pointer". + // + // We also need to change the types of some of the variables we capture: + // run_options needs to change from a pointer to a value type, and arguments + // needs to change from an ArraySlice into a vector. We use a struct instead + // of a lambda to make this explicit. + struct AsyncRunTask { + CpuExecutable* executable; + ServiceExecutableRunOptions run_options; + std::vector arguments; + std::vector unowning_buffers; + std::shared_ptr> buffers; + + void operator()() { + // Failing a CHECK here is not great, but I don't see an obvious way to + // return a failed Status asynchronously. 
+ TF_CHECK_OK(executable->ExecuteComputeFunction( + &run_options.run_options(), arguments, unowning_buffers, + /*hlo_execution_profile=*/nullptr)); + } + }; + host_stream->EnqueueTask(AsyncRunTask{ + this, *run_options, + std::vector(arguments.begin(), arguments.end()), + unowning_buffers, + std::make_shared>(std::move(buffers))}); + + return std::move(result); } /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 68ad38cba88720a04519fc2473fe6f9decbaaf93..8dd47bfb865e8a0552542f510d3365cff0d111e0 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -92,7 +92,7 @@ class CpuExecutable : public Executable { // buffer is assigned for this element. Status AllocateBuffers(DeviceMemoryAllocator* memory_allocator, int device_ordinal, - std::vector* buffers); + std::vector* buffers); // Calls the generated function performing the computation with the given // arguments using the supplied buffers. @@ -102,16 +102,12 @@ class CpuExecutable : public Executable { tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile); - // Creates a ScopedShapedBuffer for holding the result of the computation. The - // addresses (DeviceMemoryBases) are set according to buffer assignment. - // 'buffers_in_result' should point to a vector of the same size as - // 'allocated_buffers'. An element in buffers_in_result is set to true if the - // corresponding buffer is live out of the computation (and thus contained in - // the returned ShapedBuffer). + // Creates a ScopedShapedBuffer for holding the result of the computation, + // moving buffers out of allocated_buffers and into the result as appropriate. + // The addresses are set according to buffer assignment. 
StatusOr CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice allocated_buffers, - std::vector* buffers_in_result); + tensorflow::gtl::MutableArraySlice buffers); // Returns the points-to set of the root instruction of the entry // computation. Uses points-to analysis from buffer assignment. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 46fe060817b0264d90574b45a94cf1f6e5964593..97e10a89a209c057685709e7a5034052ff4376ed 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace op = xla::testing::opcode_matchers; @@ -172,7 +172,7 @@ ENTRY DotOperationFusion_TransposeFusion { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* computation = module->entry_computation(); TransposeFolding transpose_folding( @@ -202,7 +202,7 @@ ENTRY DotOperationFusion_TransposeFusion { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* computation = module->entry_computation(); TransposeFolding transpose_folding( @@ -233,7 +233,7 @@ ENTRY DotOperationFusion_TransposeFusion { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* computation = module->entry_computation(); TransposeFolding transpose_folding( @@ 
-775,7 +775,7 @@ TEST_P(GatherLoopFusionTest, GatherLoopFusion) { string hlo_string = tensorflow::strings::StrCat( "HloModule ", spec.test_name, "\n\n", spec.hlo_computation_text); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); RunFusionAndCheckOpcodesWereFused( module.get(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index 6c642080c34e72b1f28b13b340fd2e919a453201..aa872d5ec9e7593b8d2f731421c17af590729529 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -100,7 +100,8 @@ Status CpuLayoutAssignment::AddBackendConstraints( const HloComputation* computation = constraints->computation(); for (auto* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kConvolution && - PotentiallyImplementedAsEigenConvolution(*instruction)) { + PotentiallyImplementedAsEigenConvolution(*instruction, + target_machine_features_)) { const HloInstruction* convolution = instruction; const HloInstruction* lhs_instruction = convolution->operand(0); const HloInstruction* rhs_instruction = convolution->operand(1); @@ -126,7 +127,8 @@ Status CpuLayoutAssignment::AddBackendConstraints( const HloInstruction* op = instruction->operand(*op_idx); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( ColMajorShape(op->shape()), instruction, *op_idx)); - } else if (PotentiallyImplementedAsEigenDot(*instruction)) { + } else if (PotentiallyImplementedAsEigenDot(*instruction, + target_machine_features_)) { const HloInstruction* dot = instruction; // In order to implement `dot` with Eigen dot, the layouts of the lhs, // rhs, and output need to be row-major. 
@@ -177,7 +179,7 @@ Status CpuLayoutAssignment::AddBackendConstraints( } } } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h index 09adb5cb02abba5844a1740bdb50a578e1bdf8b5..3c4fe68b830d9602f009b318d4e51e9a04a27e09 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/core/lib/core/status.h" @@ -28,12 +29,16 @@ namespace cpu { class CpuLayoutAssignment : public LayoutAssignment { public: explicit CpuLayoutAssignment( - const ComputationLayout& entry_computation_layout) - : LayoutAssignment(entry_computation_layout) {} + ComputationLayout* entry_computation_layout, + const TargetMachineFeatures* target_machine_features) + : LayoutAssignment(entry_computation_layout), + target_machine_features_(*target_machine_features) {} ~CpuLayoutAssignment() override {} protected: Status AddBackendConstraints(LayoutConstraints* constraints) override; + + const TargetMachineFeatures& target_machine_features_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index ba4c5a23d3e043fd6680c2f9abc2275696737ee7..429fc7b78608da0e9cd794ac294851b326f5be24 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -49,7 +50,12 @@ class CpuLayoutAssignmentTest : public HloTestBase { protected: void AssignLayouts(HloModule* module, ComputationLayout* entry_computation_layout) { - cpu::CpuLayoutAssignment layout_assignment(*entry_computation_layout); + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout, + &target_machine_features); EXPECT_IS_OK(layout_assignment.Run(module).status()); } }; @@ -311,7 +317,12 @@ static StatusOr RunDotOutputFusion( result.addend_fusion_param = fusion_instruction->operand( fused_add->operand(1 - dot_operand_idx_in_add)->parameter_number()); - cpu::CpuLayoutAssignment layout_assignment(computation_layout); + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + cpu::CpuLayoutAssignment layout_assignment(&computation_layout, + &target_machine_features); TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something, layout_assignment.Run(module)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index f9c51f243c47b8069500eca3c9c2929b17f04e62..3ed7876715f64191f6e652d2b5cb1673df9a1b94 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -16,12 +16,16 @@ limitations under the 
License. #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace { const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; +const char* const kXlaEnableExperimentalLlvmIrGemm = + "xla_enable_experimental_llvm_ir_gemm"; +const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size"; } // namespace @@ -54,6 +58,49 @@ tensorflow::gtl::optional LlvmIrGemvTilingFactor( return tensorflow::gtl::nullopt; } +bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0; +} + +static tensorflow::StringPiece RemoveSuffix(tensorflow::StringPiece str, + tensorflow::StringPiece suffix) { + CHECK_GE(str.size(), suffix.size()); + CHECK_EQ(str.substr(str.size() - suffix.size()), suffix); + return str.substr(0, str.size() - suffix.size()); +} + +tensorflow::gtl::optional> LlvmIrGemmTileSize( + const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + auto it = extra_options_map.find(kLlvmIrGemmTileSize); + if (it == extra_options_map.end()) { + return tensorflow::gtl::nullopt; + } + + std::vector tile_components = + tensorflow::str_util::Split(it->second, ':'); + CHECK_EQ(tile_components.size(), 3); + + int64 tile_size_m; + int64 tile_size_k; + int64 tile_size_n_in_vector_width; + + CHECK(tensorflow::strings::safe_strto64(tile_components[0], &tile_size_m)); + CHECK(tensorflow::strings::safe_strto64(tile_components[1], &tile_size_k)); + + tensorflow::StringPiece tile_size_n_in_vector_width_str = + RemoveSuffix(tile_components[2], "*vectwidth"); + + 
CHECK(tensorflow::strings::safe_strto64(tile_size_n_in_vector_width_str, + &tile_size_n_in_vector_width)); + + return std::tuple(tile_size_m, tile_size_k, + tile_size_n_in_vector_width); +} + } // namespace options } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index be62ff3cc1af23408ca8a00f1372e7a998f160c6..429b9e16cbdd6f623919533582481f1640118081 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -26,8 +26,11 @@ namespace options { bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); +bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config); tensorflow::gtl::optional LlvmIrGemvTilingFactor( const HloModuleConfig& config); +tensorflow::gtl::optional> LlvmIrGemmTileSize( + const HloModuleConfig& config); } // namespace options } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 215405f6802cf1956ebec011da2fcd11b95c0c64..54c52bc08f9c53b8c6898689b18c4cb7f4bdcfd0 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -51,6 +51,8 @@ extern const char* const kEigenConvF16SymbolName = extern const char* const kEigenConvF32SymbolName = "__xla_cpu_runtime_EigenConvF32"; extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft"; +extern const char* const kEigenSingleThreadedFftSymbolName = + "__xla_cpu_runtime_EigenSingleThreadedFft"; extern const char* const kEigenSingleThreadedMatMulF16SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF16"; extern const char* const kEigenSingleThreadedMatMulF32SymbolName = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index 
1dce6efa5cd65e67ae73a2e2affe2d2d3c537508..aa0e96712302e806a389c6ad05a2c1b6634ef901 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -52,6 +52,7 @@ extern const char* const kMKLSingleThreadedMatMulF64SymbolName; extern const char* const kEigenConvF16SymbolName; extern const char* const kEigenConvF32SymbolName; extern const char* const kEigenFftSymbolName; +extern const char* const kEigenSingleThreadedFftSymbolName; extern const char* const kEigenSingleThreadedMatMulF16SymbolName; extern const char* const kEigenSingleThreadedMatMulF32SymbolName; extern const char* const kEigenSingleThreadedMatMulF64SymbolName; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index 9b39e7f5765ae5eb6a25c06eef4d74b1c00e5c91..d97802ee45d6add3c466577d7624d9ca74e2f380 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -88,8 +88,8 @@ CpuTransferManager::CpuTransferManager() : GenericTransferManager(se::host::kHostPlatformId, /*pointer_size=*/sizeof(void*)) {} -Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) { +Status CpuTransferManager::TransferLiteralToInfeed( + se::StreamExecutor* executor, const LiteralSlice& literal) { const Shape& shape = literal.shape(); VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h index 3ecb0d236498371f48caf63249f9cd4e8777752b..6dfc666f09dfa6df740cd54bea0957e3144181bc 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h @@ -38,7 +38,7 @@ class CpuTransferManager : public GenericTransferManager { 
~CpuTransferManager() override {} Status TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) override; + const LiteralSlice& literal) override; Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, const void* source) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 8db4a0650d2867cd7326206787d79aaa7c0acf9f..8eb39d615fd482cdcea716ba7b105c643a2d8b87 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -41,17 +42,17 @@ using llvm_ir::SetToFirstInsertPoint; namespace cpu { namespace { -// Loads a tile of values from a 2D tensor. -class TileLoader { +// Provides tiled access to an in-memory rank 2 array. +class MemoryTile { public: - // Constructs a TileLoader that will load a tile consisting of + // Constructs a MemoryTile that can operate on tiles consisting of // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at // `major_dim_offset` in the major dimension. The tile size along the minor // dimension is the vector size, and that is implicitly determined by `vsl`. 
- TileLoader(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder, + MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder, llvm::Value* matrix, int64 matrix_size_along_minor_dim, llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) - : vsl_(vsl) { + : vsl_(vsl), ir_builder_(ir_builder) { pointers_.reserve(tile_size_along_major_dim); for (int64 i = 0; i < tile_size_along_major_dim; i++) { llvm::Value* total_offset = ir_builder->CreateMul( @@ -61,9 +62,10 @@ class TileLoader { } } - // Load a tile consisting of `tile_size_along_major_dim_` vectors starting at - // `major_dim_offset_` in the major dimension and `minor_dim_offset` in the - // minor dimension. + // Load a tile consisting of `tile_size_along_major_dim` vectors from position + // {major: `major_dim_offset`, minor: `minor_dim_offset`}. + // + // Note: `major_dim_offset` is a parameter to the constructor. std::vector LoadTile(llvm::Value* minor_dim_offset) const { std::vector result; result.reserve(pointers_.size()); @@ -73,11 +75,104 @@ class TileLoader { return result; } + // Stores `tile` to position {major: `major_dim_offset`, minor: + // `minor_dim_offset`}. + // + // Note: `major_dim_offset` is a parameter to the constructor. + void StoreTile(tensorflow::gtl::ArraySlice tile, + llvm::Value* minor_dim_offset) const { + CHECK_EQ(tile.size(), pointers_.size()); + for (int64 i = 0; i < pointers_.size(); i++) { + vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset); + } + } + + // Loads a tile of size [`tile_size_along_major_dim`, + // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`, + // minor: `minor_dim_offset`} and then broadcasts each element into a vector + // of size vsl_.vector_size(). The (i,j)'th element of the return value is + // the (i,j)'th element in the tile broadcasted into an LLVM vector. + // + // Note: `major_dim_offset` is a parameter to the constructor. 
+ std::vector> LoadBroadcastTile( + llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const { + std::vector> result; + result.resize(pointers_.size()); + for (int64 i = 0; i < pointers_.size(); i++) { + for (int64 j = 0; j < tile_size_along_middle_dim; j++) { + result[i].push_back(vsl_->LoadBroadcast( + pointers_[i], ir_builder_->CreateAdd(minor_dim_offset, + ir_builder_->getInt64(j)))); + } + } + return result; + } + private: VectorSupportLibrary* vsl_; + llvm::IRBuilder<>* ir_builder_; std::vector pointers_; }; +// The base class for the classes representing the GEMV emitter configurations. +// +// The IR emitted (modulo the LLVM values representing the input and output +// buffers) by the row major and column major GEMV emitters should be a function +// of their configuration. This is important because their configuration is +// used as a key to cache the generated IR. +class GemvConfig { + public: + // Mixin for convenience. + template + struct User { + public: + PrimitiveType scalar_type() const { + return derived().config().scalar_type(); + } + int64 tile_rows() const { return derived().config().tile_rows(); } + int64 tile_cols() const { return derived().config().tile_cols(); } + int64 m() const { return derived().config().m(); } + int64 k() const { return derived().config().k(); } + int64 has_addend() const { return derived().config().has_addend(); } + + private: + const T& derived() const { return *static_cast(this); } + }; + + PrimitiveType scalar_type() const { return scalar_type_; } + int64 tile_rows() const { return tile_rows_; } + int64 tile_cols() const { return tile_cols_; } + int64 m() const { return m_; } + int64 k() const { return k_; } + bool has_addend() const { return has_addend_; } + + string GetCacheKey() const { + return tensorflow::strings::StrCat( + name_, "_", PrimitiveType_Name(scalar_type()), "_", tile_rows(), "_", + tile_cols(), "_", m(), "_", k(), has_addend() ? 
"_with_addend" : ""); + } + + protected: + explicit GemvConfig(string name, PrimitiveType scalar_type, int64 tile_rows, + int64 tile_cols, int64 m, int64 k, bool has_addend) + : name_(std::move(name)), + scalar_type_(scalar_type), + tile_rows_(tile_rows), + tile_cols_(tile_cols), + m_(m), + k_(k), + has_addend_(has_addend) {} + + private: + string name_; + PrimitiveType scalar_type_; + int64 tile_rows_; + int64 tile_cols_; + int64 m_; + int64 k_; + bool has_addend_; +}; + // Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the // layout of the vector does not matter). This implementation uses a tiling // scheme to improve performance. @@ -139,38 +234,46 @@ class TileLoader { // TODO(sanjoy): We should investigate if using gather loads and scatter stores // can be used here have the same inner loop for both column-major and row-major // matrix-vector products. -class ColumnMajorMatrixVectorProductEmitter { +class ColumnMajorMatrixVectorProductEmitter + : public GemvConfig::User { public: - ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, - int64 tile_rows, int64 tile_cols, - int64 m, int64 k, llvm::Value* lhs, + class Config : public GemvConfig { + public: + explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, + int64 m, int64 k, bool has_addend) + : GemvConfig(/*name=*/"col_major_gemv", scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, + /*k=*/k, /*has_addend=*/has_addend) {} + }; + + ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, llvm::Value* result, llvm::IRBuilder<>* ir_builder) - : scalar_type_(scalar_type), - tile_rows_(tile_rows), - tile_cols_(tile_cols), - m_(m), - k_(k), + : config_(config), lhs_(lhs), rhs_(rhs), addend_(addend), result_(result), ir_builder_(ir_builder), ksl_(ir_builder_), - vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") { - CHECK(tile_rows_ > 0 && 
IsPowerOfTwo(static_cast(tile_rows_))); + vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), + ir_builder_, "") { + CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast(tile_rows()))); + CHECK(!has_addend() || addend != nullptr); } void Emit(); + const Config& config() const { return config_; } + private: void EmitOuterLoopBody(llvm::Value* column, int64 column_count, bool is_first_column); - TileLoader GetLhsTileLoader(llvm::Value* column_start, int64 column_count) { - return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/m_, + MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) { + return MemoryTile(&vsl_, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/m(), /*major_dim_offset=*/column_start, /*tile_size_along_major_dim=*/column_count); } @@ -187,18 +290,14 @@ class ColumnMajorMatrixVectorProductEmitter { return result; } - void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, + void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, const std::vector& rhs_tile, int64 columns, bool is_first_column); void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column); - PrimitiveType scalar_type_; - int64 tile_rows_; - int64 tile_cols_; - int64 m_; - int64 k_; + Config config_; llvm::Value* lhs_; llvm::Value* rhs_; llvm::Value* addend_; @@ -210,26 +309,26 @@ class ColumnMajorMatrixVectorProductEmitter { void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody( llvm::Value* column, int64 column_count, bool is_first_column) { - TileLoader lhs_tile_loader = GetLhsTileLoader(/*column_start=*/column, + MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*column_start=*/column, /*column_count=*/column_count); std::vector rhs_tile = LoadRhsTile(column, /*count=*/column_count); - EmitInnerLoopTiled(&lhs_tile_loader, rhs_tile, + EmitInnerLoopTiled(&lhs_memory_tile, rhs_tile, /*columns=*/column_count, is_first_column); EmitInnerLoopEpilogue(column, 
/*columns=*/column_count, is_first_column); } void ColumnMajorMatrixVectorProductEmitter::Emit() { // See the comment on the class declaration for the algorithm used here. - int64 column_remainder = k_ % tile_cols_; - int64 column_limit = k_ - column_remainder; + int64 column_remainder = k() % tile_cols(); + int64 column_limit = k() - column_remainder; - ksl_.For("dot.outer.tiled", - /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols_, - [&](llvm::Value* column, bool is_first_column) { - EmitOuterLoopBody(column, tile_cols_, is_first_column); - }); + ksl_.ForReturnVoid("dot.outer.tiled", + /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), + [&](llvm::Value* column, bool is_first_column) { + EmitOuterLoopBody(column, tile_cols(), is_first_column); + }); if (column_remainder != 0) { EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder, @@ -238,29 +337,30 @@ void ColumnMajorMatrixVectorProductEmitter::Emit() { } void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( - TileLoader* lhs_tile_loader, const std::vector& rhs_tile, + MemoryTile* lhs_memory_tile, const std::vector& rhs_tile, int64 columns, bool is_first_column) { - int64 row_limit = m_ - (m_ % tile_rows_); - - ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit, - /*step=*/tile_rows_, [&](llvm::Value* row) { - std::vector lhs_tile = - lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row); - llvm::Value* accumulator = - is_first_column ? (addend_ ? 
vsl_.LoadVector(addend_, row) - : vsl_.GetZeroVector()) - : vsl_.LoadVector(result_, row); - for (int i = 0; i < columns; i++) { - accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); - } - vsl_.StoreVector(accumulator, result_, row); - }); + int64 row_limit = m() - (m() % tile_rows()); + + ksl_.ForReturnVoid( + "dot.inner.tiled", /*start=*/0, /*end=*/row_limit, + /*step=*/tile_rows(), [&](llvm::Value* row) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row); + llvm::Value* accumulator = + is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) + : vsl_.GetZeroVector()) + : vsl_.LoadVector(result_, row); + for (int i = 0; i < columns; i++) { + accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); + } + vsl_.StoreVector(accumulator, result_, row); + }); } void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) { - int64 row_start = m_ - (m_ % tile_rows_); - if (row_start == m_) { + int64 row_start = m() - (m() % tile_rows()); + if (row_start == m()) { return; } @@ -273,25 +373,25 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( // // initialized. 
// } - ksl_.For( + ksl_.ForReturnVoid( "dot.inner.epilg.outer", /*start=*/current_tile_col, /*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col), /*step=*/1, /*peel_first_iteration=*/false, [&](llvm::Value* col, llvm::Value* is_first_scalar_col) { llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col); llvm::Value* total_offset = - ir_builder_->CreateMul(col, ir_builder_->getInt64(m_)); + ir_builder_->CreateMul(col, ir_builder_->getInt64(m())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.For( - "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m_, + ksl_.ForReturnVoid( + "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(), /*step=*/1, [&](llvm::Value* scalar_row) { llvm::Value* product = vsl_.Mul( vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); llvm::Value* setting_result_first_time = ir_builder_->CreateAnd( is_first_scalar_col, ir_builder_->getInt1(is_first_tiled_column)); - ksl_.If( + ksl_.IfReturnVoid( setting_result_first_time, /*true_block_generator=*/ [&]() { @@ -364,51 +464,55 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( // // We have an inner epilogue loop to deal with the "B" sub-matrix and an outer // epilogue loop to deal with the C,D submatrix. 
-class RowMajorMatrixVectorProductEmitter { +class RowMajorMatrixVectorProductEmitter + : public GemvConfig::User { public: - RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows, - int64 tile_cols, int64 m, int64 k, - llvm::Value* lhs, llvm::Value* rhs, - llvm::Value* addend, llvm::Value* result, + class Config : public GemvConfig { + public: + explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, + int64 m, int64 k, bool has_addend) + : GemvConfig(/*name=*/"row_major_gemv", scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, + /*k=*/k, /*has_addend=*/has_addend) {} + }; + + RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* ir_builder) - : scalar_type_(scalar_type), - tile_rows_(tile_rows), - tile_cols_(tile_cols), - m_(m), - k_(k), + : config_(config), lhs_(lhs), rhs_(rhs), addend_(addend), result_(result), ir_builder_(ir_builder), ksl_(ir_builder_), - vsl_(scalar_type_, /*vector_size=*/tile_cols_, ir_builder_, "") { - CHECK(tile_cols_ > 0 && IsPowerOfTwo(static_cast(tile_cols_))); + vsl_(scalar_type(), /*vector_size=*/tile_cols(), ir_builder_, "") { + CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast(tile_cols()))); + CHECK(!has_addend() || addend != nullptr); } void Emit(); + const Config& config() const { return config_; } + private: - TileLoader GetLhsTileLoader(llvm::Value* row_start, int64 row_count) { - return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/k_, + MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) { + return MemoryTile(&vsl_, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/k(), /*major_dim_offset=*/row_start, /*tile_size_along_major_dim=*/row_count); } void EmitOuterLoopBody(llvm::Value* row, int64 row_count); - void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, int64 rows, + void 
EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, int64 rows, std::vector* vector_accumulators); void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows, std::vector* scalar_accumulators); - PrimitiveType scalar_type_; - int64 tile_rows_; - int64 tile_cols_; - int64 m_; - int64 k_; + Config config_; llvm::Value* lhs_; llvm::Value* rhs_; llvm::Value* addend_; @@ -420,7 +524,7 @@ class RowMajorMatrixVectorProductEmitter { void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, int64 row_count) { - TileLoader lhs_tile_loader = GetLhsTileLoader(/*row_start=*/row, + MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*row_start=*/row, /*row_count=*/row_count); std::vector vector_accumulators; std::vector scalar_accumulators; @@ -428,7 +532,7 @@ void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector()); scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar()); } - EmitInnerLoopTiled(&lhs_tile_loader, /*rows=*/row_count, + EmitInnerLoopTiled(&lhs_memory_tile, /*rows=*/row_count, &vector_accumulators); EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count, &scalar_accumulators); @@ -465,12 +569,13 @@ void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, void RowMajorMatrixVectorProductEmitter::Emit() { // See the comment on the class declaration for the algorithm used here. 
- int64 row_remainder = m_ % tile_rows_; - int64 row_limit = m_ - row_remainder; + int64 row_remainder = m() % tile_rows(); + int64 row_limit = m() - row_remainder; - ksl_.For("dot.outer.tiled", - /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows_, - [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows_); }); + ksl_.ForReturnVoid( + "dot.outer.tiled", + /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), + [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); if (row_remainder != 0) { EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder); @@ -478,48 +583,395 @@ void RowMajorMatrixVectorProductEmitter::Emit() { } void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( - TileLoader* lhs_tile_loader, int64 rows, + MemoryTile* lhs_memory_tile, int64 rows, std::vector* vector_accumulators) { - int64 column_limit = k_ - (k_ % tile_cols_); - - ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, - /*step=*/tile_cols_, [&](llvm::Value* col) { - std::vector lhs_tile = - lhs_tile_loader->LoadTile(/*minor_dim_offset=*/col); - llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); - for (int i = 0; i < rows; i++) { - llvm::Value* old_sum = (*vector_accumulators)[i].Get(); - (*vector_accumulators)[i].Set( - vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); - } - }); + int64 column_limit = k() - (k() % tile_cols()); + + ksl_.ForReturnVoid("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, + /*step=*/tile_cols(), [&](llvm::Value* col) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); + llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); + for (int i = 0; i < rows; i++) { + llvm::Value* old_sum = (*vector_accumulators)[i].Get(); + (*vector_accumulators)[i].Set(vsl_.Add( + old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); + } + }); } void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( llvm::Value* current_tile_row, int64 rows, std::vector* scalar_accumulators) { - int64 column_start = k_ - 
(k_ % tile_cols_); - if (column_start == k_) { + int64 column_start = k() - (k() % tile_cols()); + if (column_start == k()) { return; } for (int r = 0; r < rows; r++) { llvm::Value* total_offset = ir_builder_->CreateMul( ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row), - ir_builder_->getInt64(k_)); + ir_builder_->getInt64(k())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k_, - /*step=*/1, [&](llvm::Value* scalar_col) { - llvm::Value* product = - vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), - vsl_.LoadScalar(rhs_, scalar_col)); - llvm::Value* old_value = (*scalar_accumulators)[r].Get(); - (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); - }); + ksl_.ForReturnVoid( + "dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(), + /*step=*/1, [&](llvm::Value* scalar_col) { + llvm::Value* product = + vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), + vsl_.LoadScalar(rhs_, scalar_col)); + llvm::Value* old_value = (*scalar_accumulators)[r].Get(); + (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); + }); } } +// This class implements a tiled matrix multiplication algorithm, intended for +// use as the innermost GEBP loop in a GEMM kernel (GEBP is described in "Goto, +// Kazushige, and Robert Van De Geijn. "High-performance implementation of the +// level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 (2008): +// 4). +// +// This only supports canonical dot operations (i.e. where the lhs contraction +// dimension is 1 and the rhs contraction dimension is 0) over row major +// matrices. +class MatrixMatrixBlockPanelEmitter { + public: + // Describe the dimensions of the GEBP kernel. These will usually not be the + // dimensions of the GEMM itself, the GEMM will usually be broken up into GEBP + // kernels with smaller dimensions. 
+ class Dimensions { + public: + explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {} + + int64 m() const { return m_; } + int64 k() const { return k_; } + int64 n() const { return n_; } + + string ToString() const { + return tensorflow::strings::StrCat(m(), "x", k(), "x", n()); + } + + private: + const int64 m_; + const int64 k_; + const int64 n_; + }; + + // Represents the configuration of the GEBP emitter. The LLVM IR emitted by + // the emitter, modulo the LLVM values holding the input and output buffers, + // must be a function of the instance of `Config` passed to it. + // + // `dims` holds the matrix multiplication dimensions. + // + // `max_vectorization_width` is the maximum vector width (i.e. the width of + // the largest vector register we will use). This can be larger than the + // largest vector register supported by the machine -- LLVM will legalize + // these large vector widths into legally sized vectors. + // + // `max_vector_count` is the maximum number of vectors of size + // `max_vectorization_width` that we will attempt to process at once. + // + // `min_vectorization_width` is the smallest vector width the emitter will use + // -- below that it will devolve to using a scalar loop. + // + // The innermost reduction loop executes the matrix multiply in tiles of size + // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`, + // ] in the RHS. 
+ class Config { + public: + explicit Config(PrimitiveType scalar_type, Dimensions dims, + int64 max_vectorization_width, int64 max_vector_count, + int64 min_vectorization_width, int64 tile_size_m, + int64 tile_size_k) + : scalar_type_(scalar_type), + dims_(dims), + max_vectorization_width_(max_vectorization_width), + max_vector_count_(max_vector_count), + min_vectorization_width_(min_vectorization_width), + tile_size_m_(tile_size_m), + tile_size_k_(tile_size_k) {} + + string GetCacheKey() const { + return tensorflow::strings::StrCat( + "gebp_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(), + "_", max_vectorization_width(), "_", min_vectorization_width(), "_", + tile_size_m(), "_", tile_size_k()); + } + + PrimitiveType scalar_type() const { return scalar_type_; } + Dimensions dims() const { return dims_; } + int64 max_vectorization_width() const { return max_vectorization_width_; } + int64 max_vector_count() const { return max_vector_count_; } + int64 min_vectorization_width() const { return min_vectorization_width_; } + + int64 tile_size_m() const { return tile_size_m_; } + int64 tile_size_k() const { return tile_size_k_; } + + private: + PrimitiveType scalar_type_; + Dimensions dims_; + int64 max_vectorization_width_; + int64 max_vector_count_; + int64 min_vectorization_width_; + int64 tile_size_m_; + int64 tile_size_k_; + }; + + // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies + // `lhs` with `rhs` and stores the result in `result`. 
+ explicit MatrixMatrixBlockPanelEmitter(Config config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* result, + llvm::IRBuilder<>* ir_builder) + : lhs_(lhs), + rhs_(rhs), + result_(result), + config_(config), + ir_builder_(ir_builder), + ksl_(ir_builder_) { + CHECK(max_vectorization_width() > 0 && + IsPowerOfTwo(static_cast(max_vectorization_width()))); + CHECK_GT(max_vector_count(), 0); + CHECK(min_vectorization_width() > 0 && + IsPowerOfTwo(static_cast(min_vectorization_width()))); + CHECK_GE(max_vectorization_width(), min_vectorization_width()); + CHECK_GT(tile_size_k(), 0); + } + + void Emit(); + + private: + // The HandleResiduesOnX helpers split the iteration space for dimension X + // into a multiple of the tile size on dimension X and an epilogue. These + // helpers ultimately call into `EmitTiledGemm` for emitting the + // tiled GEMM kernel. + + void HandleResiduesOnN(); + void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start, + llvm::Value* n_end); + void HandleResiduesOnM(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end); + + // This emits a tiled GEMM kernel. For a detailed description see the comment + // on the implementation. 
+ void EmitTiledGemm(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, + llvm::Value* m_end); + + llvm::Value* GetInt64(int64 value) { return ir_builder_->getInt64(value); } + + Config config() const { return config_; } + Dimensions dims() const { return config().dims(); } + + int64 max_vectorization_width() const { + return config().max_vectorization_width(); + } + int64 max_vector_count() const { return config().max_vector_count(); } + int64 min_vectorization_width() const { + return config().min_vectorization_width(); + } + int64 tile_size_m() const { return config().tile_size_m(); } + int64 tile_size_k() const { return config().tile_size_k(); } + PrimitiveType scalar_type() const { return config().scalar_type(); } + + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* result_; + Config config_; + + llvm::IRBuilder<>* ir_builder_; + KernelSupportLibrary ksl_; +}; + +void MatrixMatrixBlockPanelEmitter::Emit() { HandleResiduesOnN(); } + +void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { + // We can only iterate the `n` dimension for an extent that is divisible by + // the vectorization width. So we emit an outer loop that first processes the + // largest extent in `n` that is divisible by max_vectorization_width, then + // the largest remaining extent that is divisible by max_vectorization_width / + // 2 etc. 
+ + int64 current_vectorization_width = + max_vector_count() * max_vectorization_width(); + int64 current_vector_count = max_vector_count(); + + int64 n_start = 0; + while (n_start != dims().n() && + current_vectorization_width >= min_vectorization_width()) { + int64 n_end = dims().n() - (dims().n() % current_vectorization_width); + if (n_start != n_end) { + VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, + ir_builder_, "gebp"); + HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end)); + n_start = n_end; + } + if (current_vector_count == 1) { + current_vectorization_width /= 2; + } else { + current_vector_count--; + current_vectorization_width = + current_vector_count * max_vectorization_width(); + } + } + + if (n_start != dims().n()) { + VectorSupportLibrary vsl(scalar_type(), 1, ir_builder_, "gebp"); + ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { + llvm::Value* n_i_next = + ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1)); + HandleResiduesOnK(&vsl, n_i, n_i_next); + }); + } +} + +void MatrixMatrixBlockPanelEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, + llvm::Value* n_start, + llvm::Value* n_end) { + int64 k_start = 0; + int64 k_end = dims().k() - (dims().k() % tile_size_k()); + if (k_end != k_start) { + HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end), + n_start, n_end); + k_start = k_end; + } + + if (k_start != dims().k()) { + HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start), + GetInt64(dims().k()), n_start, n_end); + } +} + +void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM( + VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, + llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) { + const int64 m_end = dims().m() - dims().m() % tile_size_m(); + EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(), + GetInt64(0), GetInt64(m_end)); + + if (m_end != dims().m()) { + EmitTiledGemm(vsl, tile_size_k, 
k_start, k_end, n_start, n_end, + dims().m() - m_end, GetInt64(m_end), GetInt64(dims().m())); + } +} + +// The loop structure is: +// +// Iterate over dimension M as m: +// Iterate over dimension N as n: +// Iterate over dimension K as k: +// OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n]) +// +// I.e. just a tiled version of a "naive" GEMM. +// +// The tiling scheme is as follows: +// +// Let the LHS be: +// +// +----+----+----+ +// | a0 | b0 | c0 | . +// +----+----+----+ . +// | a1 | b1 | c1 | . +// +----+----+----+ +// .. .. +// +// and the RHS be: +// +// +----+----+----+----+ +// | p0 | p1 | p2 | p3 | . +// +----+----+----+----+ . +// | q0 | q1 | q2 | q3 | . +// +----+----+----+----+ +// | r0 | r1 | r2 | r3 | . +// +----+----+----+----+ . +// ...... ...... +// +// and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted +// by `vsl`) be 4. Then we want to matrix multiply this tile to get a [2,4] +// matrix that we can increment the result matrix by. +// +// First broadcast each row in the LHS to 3 vectors of width 4, giving us a rank +// 3 array, L, of dimension [2,3,4]: +// +// L[0,_,_] * L[1,_,_] +// * +// +----+----+----+----+ * +----+----+----+----+ +// | a0 | a0 | a0 | a0 | * | a1 | a1 | a1 | a1 | +// +----+----+----+----+ * +----+----+----+----+ +// | b0 | b0 | b0 | b0 | * | b1 | b1 | b1 | b1 | +// +----+----+----+----+ * +----+----+----+----+ +// | c0 | c0 | c0 | c0 | * | c1 | c1 | c1 | c1 | +// +----+----+----+----+ * +----+----+----+----+ +// +// +// Then we FMA L[0,_,_] with the RHS to get the first row of the result and +// L[1,_,_] with the RHS to get the second row of the result.
For example, +// L[0,_,_] is computed as: +// +// +----+----+----+----+ +----+----+----+----+ +// | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 | + +// +----+----+----+----+ +----+----+----+----+ +// +// +----+----+----+----+ +----+----+----+----+ +// | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 | + +// +----+----+----+----+ +----+----+----+----+ +// +// +----+----+----+----+ +----+----+----+----+ +// | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 | +// +----+----+----+----+ +----+----+----+----+ +// +// to get: +// +// +-------------------+-------------------+-------------------+--------- +// | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ... +// +-------------------+-------------------+-------------------+--------- +void MatrixMatrixBlockPanelEmitter::EmitTiledGemm( + VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, + llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { + ksl_.ForReturnVoid( + "dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { + MemoryTile result_memory_tile( + vsl, ir_builder_, /*matrix=*/result_, + /*matrix_size_along_minor_dim=*/dims().n(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + MemoryTile lhs_memory_tile(vsl, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/dims().k(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + ksl_.ForReturnVoid( + "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { + TileVariable result_tile_var(vsl, + result_memory_tile.LoadTile(n_i)); + ksl_.ForReturnVoid( + "dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { + MemoryTile rhs_memory_tile(vsl, ir_builder_, rhs_, + dims().n(), k_i, tile_size_k); + std::vector> lhs_tile = + lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k); + std::vector rhs_tile = + rhs_memory_tile.LoadTile(n_i); + std::vector result_tile = + result_tile_var.Get(); + for (int64 r_m_i 
= 0; r_m_i < tile_size_m; r_m_i++) { + for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) { + result_tile[r_m_i] = + vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i], + result_tile[r_m_i]); + } + } + result_tile_var.Set(result_tile); + }); + + result_memory_tile.StoreTile(result_tile_var.Get(), n_i); + }); + }); +} + } // namespace DotOpEmitter::DotOpEmitter(const HloInstruction& dot, @@ -541,7 +993,7 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, hlo_module_config_(hlo_module_config), target_machine_features_(target_machine_features) {} -/* static */ tensorflow::Status DotOpEmitter::EmitDotOperation( +/* static */ Status DotOpEmitter::EmitDotOperation( const HloInstruction& dot, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, @@ -557,6 +1009,89 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, return dot_emitter.Emit(); } +bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( + const DotOpEmitter::MatMultDims& mat_mult_dims) { + if (!EnableExperimentalLlvmIrGemm() || ShouldUseMultiThreadedEigen()) { + return false; + } + + if (mat_mult_dims.lhs_non_canonical || mat_mult_dims.rhs_non_canonical) { + return false; + } + + PrimitiveType primitive_type = dot_.shape().element_type(); + + switch (primitive_type) { + default: + return false; + + case F32: + case F64: + case S32: + case S64: + break; + } + + if (!(mat_mult_dims.lhs_column_major == mat_mult_dims.rhs_column_major && + mat_mult_dims.rhs_column_major == mat_mult_dims.target_column_major)) { + return false; + } + + llvm::Value* lhs = lhs_array_.GetBasePointer(); + llvm::Value* rhs = rhs_array_.GetBasePointer(); + llvm::Value* target = target_array_.GetBasePointer(); + int64 m = mat_mult_dims.m; + int64 k = mat_mult_dims.k; + int64 n = mat_mult_dims.n; + + if (mat_mult_dims.lhs_column_major) { + std::swap(lhs, rhs); + std::swap(m, n); + } + + int64 size_bytes = m * n * 
ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); + ir_builder_->CreateMemSet( + target, ir_builder_->getInt8(0), size_bytes, + target_machine_features_.minimum_alignment_for_allocation(size_bytes)); + + int64 max_target_vector_width = + target_machine_features_.vector_register_num_elements( + *ir_builder_->GetInsertBlock()->getParent(), primitive_type); + + int64 tile_size_m, tile_size_k, tile_size_n_in_vector_width; + std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) = + GetGemmTileSize(); + + MatrixMatrixBlockPanelEmitter::Config config( + /*scalar_type=*/primitive_type, + MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, + /*max_vectorization_width=*/max_target_vector_width, + /*max_vector_count=*/tile_size_n_in_vector_width, + /*min_vectorization_width=*/std::min<int64>(4, max_target_vector_width), + /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k); + + VLOG(2) << "Emitting GEBP kernel in LLVM IR with config " + << config.GetCacheKey(); + + const bool enable_fast_math = + hlo_module_config_.debug_options().xla_enable_fast_math(); + const bool optimize_for_size = + options::OptimizeForSizeRequested(hlo_module_config_); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, ir_builder_, + config.GetCacheKey(), lhs, rhs, target, + [this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) { + MatrixMatrixBlockPanelEmitter gebp_emitter( + config, /*lhs=*/lhs, /*rhs=*/rhs, + /*result=*/target, ir_builder_); + gebp_emitter.Emit(); + }); + + return true; +} + bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (dot_.shape().dimensions_size() != 2) { return false; @@ -609,7 +1144,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { } if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) { - return false; + return EmitExperimentalGebpDotIfEnabled(mat_mult_dims); } int64 tiling_factor = GetGemvTilingFactor(); @@ -642,47
+1177,39 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (is_column_major_matrix_vector) { VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m << " and k = " << k; - int64 tile_rows = vector_register_element_size; - int64 tile_cols = tiling_factor; - - string kernel_name = tensorflow::strings::StrCat( - "col_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, - "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : ""); + ColumnMajorMatrixVectorProductEmitter::Config config( + /*scalar_type=*/primitive_type, + /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor, + /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr); KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, - lhs_op, rhs_op, + /*optimize_for_size=*/optimize_for_size, ir_builder_, + config.GetCacheKey(), lhs_op, rhs_op, addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, - [this, tile_rows, tile_cols, m, k, primitive_type]( - llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, - llvm::Value* result_op) { + [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, + llvm::Value* addend_op, llvm::Value* result_op) { ColumnMajorMatrixVectorProductEmitter emitter( - primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, - addend_op, result_op, ir_builder_); + config, lhs_op, rhs_op, addend_op, result_op, ir_builder_); emitter.Emit(); }); } else { VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m << " and k = " << k; - int64 tile_rows = tiling_factor; - int64 tile_cols = vector_register_element_size; - - string kernel_name = tensorflow::strings::StrCat( - "row_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, - "_", tile_cols, "_", m, "_", k, addend_array_ ? 
"_with_addend" : ""); + RowMajorMatrixVectorProductEmitter::Config config( + /*scalar_type=*/primitive_type, + /*tile_rows=*/tiling_factor, /*tile_cols=*/vector_register_element_size, + /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr); KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, - lhs_op, rhs_op, + /*optimize_for_size=*/optimize_for_size, ir_builder_, + config.GetCacheKey(), lhs_op, rhs_op, addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, - [this, tile_rows, tile_cols, m, k, primitive_type]( - llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, - llvm::Value* result_op) { + [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, + llvm::Value* addend_op, llvm::Value* result_op) { RowMajorMatrixVectorProductEmitter emitter( - primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, - addend_op, result_op, ir_builder_); + config, lhs_op, rhs_op, addend_op, result_op, ir_builder_); emitter.Emit(); }); } @@ -690,7 +1217,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { return true; } -tensorflow::Status DotOpEmitter::Emit() { +Status DotOpEmitter::Emit() { // The dot operation performs a sum of products over dimension 0 of the left // hand side operand and dimension 1 of the right hand side operand. // @@ -734,7 +1261,7 @@ tensorflow::Status DotOpEmitter::Emit() { CHECK_EQ(addend_array_, nullptr); - if (PotentiallyImplementedAsEigenDot(dot_)) { + if (PotentiallyImplementedAsEigenDot(dot_, target_machine_features_)) { return EmitCallToRuntime(); } @@ -774,8 +1301,11 @@ tensorflow::Status DotOpEmitter::Emit() { // from messing up the vectorization. 
std::unique_ptr reduction_loop = loop_nest.AddLoop( 0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction", - /*prevent_unrolling=*/lhs_reduction_along_minor_dimension && - rhs_reduction_along_minor_dimension); + /*unroll_mode=*/ + (lhs_reduction_along_minor_dimension && + rhs_reduction_along_minor_dimension) + ? xla::llvm_ir::UnrollMode::kNoUnroll + : xla::llvm_ir::UnrollMode::kDefaultUnroll); // The final entry in the rhs and lhs indexes is the indvar of the // reduction loop. @@ -868,10 +1398,10 @@ tensorflow::Status DotOpEmitter::Emit() { // loop. ir_builder_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status DotOpEmitter::EmitScalarDot() { +Status DotOpEmitter::EmitScalarDot() { // A scalar dot is just a scalar multiply. llvm::Value* result; llvm::Value* lhs_value = @@ -896,10 +1426,10 @@ tensorflow::Status DotOpEmitter::EmitScalarDot() { result = ir_builder_->CreateFMul(lhs_value, rhs_value); } target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status DotOpEmitter::EmitCallToRuntime() { +Status DotOpEmitter::EmitCallToRuntime() { // The signature of the Eigen runtime matmul function is: // // (void)(void* run_options, float* out, float* lhs, float* rhs, @@ -908,8 +1438,7 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { // The two transpose_... parameters are actually booleans, but we use int32 // to avoid target-dependent calling convention details. 
- bool multi_threaded = - hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + bool multi_threaded = ShouldUseMultiThreadedEigen(); bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn(); PrimitiveType type = target_array_.GetShape().element_type(); llvm::Type* float_type; @@ -1001,7 +1530,7 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { ir_builder_->getInt64(mat_mult_dims.k), ir_builder_->getInt32(transpose_lhs), ir_builder_->getInt32(transpose_rhs)}); - return tensorflow::Status::OK(); + return Status::OK(); } DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { @@ -1018,7 +1547,9 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { /*lhs_column_major=*/LayoutUtil::Minor(lhs_shape.layout(), 0) == 0, /*lhs_non_canonical=*/dim_nums.lhs_contracting_dimensions(0) == 0, /*rhs_column_major=*/LayoutUtil::Minor(rhs_shape.layout(), 0) == 0, - /*rhs_non_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 1}; + /*rhs_non_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 1, + /*target_column_major=*/ + LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0}; } llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( @@ -1058,19 +1589,39 @@ static bool IsRank2WithNoPadding(const Shape& shape) { // In a gemm operation where output = lhs * rhs, check whether the given shapes // are valid for the operation. -static bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape) { +static bool AreValidGemmShapes( + const Shape& lhs_shape, const Shape& rhs_shape, const Shape& output_shape, + const TargetMachineFeatures& target_machine_features) { // The inputs and the output must // 1) be matrices with no padding, and // 2) have an allowed element type. 
PrimitiveType output_primitive_type = output_shape.element_type(); - return (output_primitive_type == F64 || output_primitive_type == F32 || - output_primitive_type == F16) && - IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && - IsRank2WithNoPadding(output_shape); + if (!(output_primitive_type == F64 || output_primitive_type == F32 || + output_primitive_type == F16)) { + return false; + } + + if (!(IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && + IsRank2WithNoPadding(output_shape))) { + return false; + } + + auto is_aligned = [&](const Shape& shape) { + return GetMinimumAlignmentForArray(shape, target_machine_features) >= + TargetMachineFeatures::kEigenExpectedTensorAlignment; + }; + + if (!is_aligned(lhs_shape) || !is_aligned(rhs_shape) || + !is_aligned(output_shape)) { + return false; + } + + return true; } -bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { +bool PotentiallyImplementedAsEigenDot( + const HloInstruction& hlo, + const TargetMachineFeatures& target_machine_features) { // For certain types of Dot, we can call Eigen if (hlo.opcode() == HloOpcode::kDot) { const Shape& lhs_shape = hlo.operand(0)->shape(); @@ -1087,7 +1638,8 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { // If gemm can accept the operand shapes, use it rather than a custom // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { + if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape(), + target_machine_features)) { const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers(); // The size of the reduction dimension should match. 
The shape inference // guarantees this invariant, so the check here is for programming diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index a20bf2f9db3ad3b85ec29038b48d5d0ab095197f..ed2a18976a0f1a88e7bb4632d3a63167d5c146ad 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -31,7 +31,9 @@ limitations under the License. namespace xla { namespace cpu { -bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo); +bool PotentiallyImplementedAsEigenDot( + const HloInstruction& hlo, + const TargetMachineFeatures& target_machine_features); // Returns the index for an operand to `hlo` that should ideally be column // major. Returns nullopt if there is no such operand or if `hlo` is not a dot @@ -55,7 +57,7 @@ class DotOpEmitter { // dimensions as the result, and the result is computed as `addend_array` + // dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported // for Matrix-vector products. - static tensorflow::Status EmitDotOperation( + static Status EmitDotOperation( const HloInstruction& dot, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, @@ -74,18 +76,18 @@ class DotOpEmitter { const TargetMachineFeatures& target_machine_features); // Emits the IR to perform the dot operation. - tensorflow::Status Emit(); + Status Emit(); // Emits instructions to perform a scalar dot product (a multiply of the // LHS and RHS) and store the results in the target. - tensorflow::Status EmitScalarDot(); + Status EmitScalarDot(); // Emit an LLVM IR implementation of the dot operation if we can. Returns // true if an LLVM IR implementation was emitted. bool EmitLlvmIrDotIfProfitable(); // Emits a call to the CPU runtime to perform the matrix multiply. 
- tensorflow::Status EmitCallToRuntime(); + Status EmitCallToRuntime(); // Emits a series of nested loops for iterating over an operand array in the // dot operation. Loops are constructed in major to minor dimension layout @@ -110,17 +112,20 @@ class DotOpEmitter { // The number of columns on the RHS. int64 n; - // True if the LHS matrix column major. + // True if the LHS matrix is column major. bool lhs_column_major; // True if the LHS contraction dimension is not 1. bool lhs_non_canonical; - // True if the RHS matrix column major. + // True if the RHS matrix is column major. bool rhs_column_major; // True if the RHS contraction dimension is not 0. bool rhs_non_canonical; + + // True if the result matrix is column major. + bool target_column_major; }; // Get the MatMultDims instance for the dot product this DotOpEmitter @@ -128,6 +133,8 @@ class DotOpEmitter { // of rank 2 as well). MatMultDims GetMatMultDims() const; + bool EmitExperimentalGebpDotIfEnabled(const MatMultDims& mat_mult_dims); + // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector // registers. int64 GetGemvTilingFactor() const { @@ -136,6 +143,28 @@ class DotOpEmitter { .value_or(kDefaultTilingFactor); } + std::tuple GetGemmTileSize() const { + // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz + // + // TODO(b/80093688): Tune for other architectures and centralize this + // information in one place. + const std::tuple kDefaultTileSize = + std::tuple(11, 9, 1); + return options::LlvmIrGemmTileSize(hlo_module_config_) + .value_or(kDefaultTileSize); + } + + // Returns true if we should use an experimental implementation of GEMM + // (general matrix matrix multiplication) if possible. + bool EnableExperimentalLlvmIrGemm() const { + return options::EnableExperimentalLlvmIrGemm(hlo_module_config_); + } + + // Returns true if we should call into multi-threaded Eigen routines. 
+ bool ShouldUseMultiThreadedEigen() { + return hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + } + const HloInstruction& dot_; const llvm_ir::IrArray& target_array_; const llvm_ir::IrArray& lhs_array_; diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc index 7dcc4ca7fa08b478f24065275ffa69725dc51682..c56286559158758ca6db5ae097729286bde346f0 100644 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc @@ -26,13 +26,13 @@ limitations under the License. namespace xla { namespace cpu { -void ExternalConstantPool::Insert(string name, const Literal& literal, +void ExternalConstantPool::Insert(string name, const LiteralSlice& literal, int64 alignment) { CHECK(!ShapeUtil::IsTuple(literal.shape())); CHECK(alignment > 0 && IsPowerOfTwo(static_cast(alignment))); CHECK(entries_.find(name) == entries_.end()); - int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); + const int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); void* raw_pointer = tensorflow::port::AlignedMalloc( literal_size, std::max(alignment, sizeof(void*))); CHECK(raw_pointer != nullptr) << "failed to allocate " << literal_size diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h index 8008a56df4dbf16e7b57aee8a344058bb0d5883d..0677f5f0b58005079890052a426e5f48c5d09ed1 100644 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h @@ -43,7 +43,7 @@ class ExternalConstantPool { // The constant pool copies out the contents of `literal` into a buffer it // owns -- it does not keep pointers to `literal`, or to memory owned by // `literal`. 
- void Insert(string name, const Literal& literal, int64 alignment); + void Insert(string name, const LiteralSlice& literal, int64 alignment); // Find the constant with name `name` in this constant pool. If there isn't // such constant, return nullptr. diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index f209a69e3cd0f8d336d61bafd1e22be8bc88ca3f..b560b7531c0d24e6f670e61a15dce295d9fa2a49 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -24,8 +24,25 @@ limitations under the License. namespace xla { namespace cpu { +int64 GetMinimumAlignmentForArray( + const Shape& shape, const TargetMachineFeatures& target_machine_features) { + CHECK(ShapeUtil::IsArray(shape)); + CHECK(!LayoutUtil::HasLayout(shape) || LayoutUtil::IsDense(shape.layout())); + + // We don't require a layout to be set on `shape`. This only works on CPU + // because we don't pad our tensors or otherwise have complicated data tiling + // schemes. + + int64 allocation_size_bytes = + ShapeUtil::ElementsIn(shape) * + ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()); + return target_machine_features.minimum_alignment_for_allocation( + allocation_size_bytes); +} + bool PotentiallyImplementedAsEigenConvolution( - const HloInstruction& convolution) { + const HloInstruction& convolution, + const TargetMachineFeatures& target_machine_features) { // The following conditions are necessary (but not sufficient) for // implementing `convolution` with Eigen convolution: // - the input and kernel have a non-zero number of elements. @@ -35,6 +52,18 @@ bool PotentiallyImplementedAsEigenConvolution( // To be sufficient, certain layout constraints need to be satisfied as well. 
const Shape& input_shape = convolution.operand(0)->shape(); const Shape& kernel_shape = convolution.operand(1)->shape(); + const Shape& output_shape = convolution.shape(); + + auto is_aligned = [&](const Shape& shape) { + return GetMinimumAlignmentForArray(shape, target_machine_features) >= + TargetMachineFeatures::kEigenExpectedTensorAlignment; + }; + + if (!is_aligned(input_shape) || !is_aligned(kernel_shape) || + !is_aligned(output_shape)) { + return false; + } + if (ShapeUtil::HasZeroElements(input_shape) || ShapeUtil::HasZeroElements(kernel_shape)) { return false; @@ -71,7 +100,6 @@ bool PotentiallyImplementedAsEigenConvolution( } } - const Shape& output_shape = convolution.shape(); return dnums.input_batch_dimension() == 0 && dnums.input_feature_dimension() == input_shape.dimensions_size() - 1 && dnums.output_batch_dimension() == 0 && diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h index 34b2003916933f5ec0a15d9e219063c0a912fa40..68fbc7caaa9bfec0ecd7cc7f473c8ca8afce19db 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h @@ -17,13 +17,20 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_ #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { namespace cpu { bool PotentiallyImplementedAsEigenConvolution( - const HloInstruction& convolution); + const HloInstruction& convolution, + const TargetMachineFeatures& target_machine_features); + +// Computes the minimum alignment guaranteed for a tensor of shape `shape` on +// the target machine. 
+int64 GetMinimumAlignmentForArray( + const Shape& shape, const TargetMachineFeatures& target_machine_features); // Dynamic loop bounds are specified as an array of dimension index // [start, limit) pairs of ir values (one for each partitioned outer dimension). diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc index 215f48c4cc1a1a6b13d98dff76e0d1f0f773f5c1..530ebce854fedf4e4db12139d5b56087b1176a6c 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc @@ -15,8 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -34,12 +35,17 @@ ENTRY Conv { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* entry_computation = module->entry_computation(); HloInstruction* conv_instr = entry_computation->root_instruction(); - EXPECT_FALSE(cpu::PotentiallyImplementedAsEigenConvolution(*conv_instr)); + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + EXPECT_FALSE(cpu::PotentiallyImplementedAsEigenConvolution( + *conv_instr, target_machine_features)); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 55e5aa5063d0ed0e71c6fed062e549dddc3e1e8d..59223fddac2f5f7e2e85de4d37e4b6c5760ae697 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ 
-83,7 +83,7 @@ IrEmitter::IrEmitter( llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - llvm::TargetMachine* target_machine, + const TargetMachineFeatures* target_machine_features, ExternalConstantPool* external_constant_pool) : assignment_(assignment), module_(llvm_module), @@ -94,7 +94,7 @@ IrEmitter::IrEmitter( alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), hlo_module_config_(hlo_module.config()), is_top_level_computation_(false), - target_machine_features_(target_machine), + target_machine_features_(*target_machine_features), external_constant_pool_(external_constant_pool) { ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() @@ -160,39 +160,44 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } -llvm::GlobalVariable* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { - llvm::GlobalVariable* result; +llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { + llvm::Constant* result; // We avoid creating large constants in the LLVM IR since LLVM is not // efficient for large constant arrays. We still emit "small enough" constant // arrays into the Ir, in the off chance the LLVM optimizer can do something // interesting with it. + // + // TODO(b/29904935): Remove the large constant pool. 
const int kMaxInternalConstantSizeInBytes = 128; if (external_constant_pool_ && ByteSizeOf(literal.shape()) >= kMaxInternalConstantSizeInBytes) { string global_name = tensorflow::strings::StrCat( "constant_global_", external_global_constant_counter_++); - result = new llvm::GlobalVariable( + llvm::GlobalVariable* result_global = new llvm::GlobalVariable( /*Module=*/*module_, /*Type=*/IrShapeType(literal.shape()), /*isConstant=*/true, /*Linkage=*/llvm::GlobalValue::ExternalLinkage, /*Initializer=*/nullptr, /*Name=*/AsStringRef(global_name)); - result->setAlignment(MinimumAlignmentForShape(literal.shape())); + result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); external_constant_pool_->Insert(global_name, literal, MinimumAlignmentForShape(literal.shape())); + result = result_global; } else { llvm::Constant* initializer = llvm_ir::ConvertLiteralToIrConstant(literal, module_); - result = new llvm::GlobalVariable( + llvm::GlobalVariable* result_global = new llvm::GlobalVariable( /*Module=*/*module_, /*Type=*/initializer->getType(), /*isConstant=*/true, /*Linkage=*/llvm::GlobalValue::PrivateLinkage, /*Initializer=*/initializer, /*Name=*/""); - result->setAlignment(MinimumAlignmentForShape(literal.shape())); + result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); + result = llvm::ConstantExpr::getBitCast( + result_global, IrShapeType(literal.shape())->getPointerTo()); } return result; } @@ -200,7 +205,7 @@ llvm::GlobalVariable* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { Status IrEmitter::HandleConstant(HloInstruction* constant) { VLOG(2) << "HandleConstant: " << constant->ToString(); const Literal& literal = constant->literal(); - llvm::GlobalVariable* global_for_const; + llvm::Constant* global_for_const; auto it = emitted_literals_.find(&literal); if (it != emitted_literals_.end()) { @@ -227,32 +232,6 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) { } } -// Calculate the alignment of a buffer with a particular 
size. -int IrEmitter::MinimumAlignmentForBufferSize(int64 buffer_size) { - // GLibc returns a pointer with alignment 8 on 32-bit platforms and 16 on - // 64-bit platforms. TCMalloc returns a pointer with alignment 8 for - // allocations smaller than kMallocAlignmentThreshold bytes and at least - // alignment 16 for allocations greater than or equal to - // kMallocAlignmentThreshold bytes. N.B. We could improve on this lower bound - // by explicitly allocating the memory with posix_memalign. This is - // complicated by our desire to allow parameter buffers created by clients to - // be consumed directly by the JIT. - if (buffer_size == 0) { - // No need to align empty buffers. - return 1; - } - - const int64 kMallocAlignmentThreshold = 512; - - int pointer_size = module_->getDataLayout().getPointerSize(); - int buffer_alignment = buffer_size >= kMallocAlignmentThreshold - ? 2 * pointer_size - : pointer_size; - DCHECK_GT(buffer_alignment, 0); - - return buffer_alignment; -} - // Calculate the alignment of a buffer allocated for a given primitive type. 
int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) { int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); @@ -277,7 +256,7 @@ int IrEmitter::MinimumAlignmentForShape(const Shape& shape) { DCHECK_GE(buffer_size, 0); DCHECK_LE(buffer_size, SIZE_MAX); - return MinimumAlignmentForBufferSize(buffer_size); + return target_machine_features_.minimum_alignment_for_allocation(buffer_size); } void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load, @@ -290,7 +269,8 @@ void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load, void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load, int64 buffer_size) { - int alignment = MinimumAlignmentForBufferSize(buffer_size); + int alignment = + target_machine_features_.minimum_alignment_for_allocation(buffer_size); if (alignment > 1) { llvm_ir::SetAlignmentMetadataForLoad(load, alignment); } @@ -530,7 +510,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { HloComputation* function = reduce_window->to_apply(); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*reduce_window, /*operands=*/{operand}, - /*supported_types=*/{F32, BF16})); + /*supported_types=*/{F32, BF16, S32})); // TODO(b/31410564): Implement dilation for reduce-window. if (window_util::HasDilation(window)) { @@ -861,7 +841,8 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // TODO(tonywy): Add PotentiallyImplementedAsMKLCovolution to support // different data layouts. 
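Review note on the alignment changes above: the removed `IrEmitter::MinimumAlignmentForBufferSize` and the `is_aligned` lambdas added to `AreValidGemmShapes` and `PotentiallyImplementedAsEigenConvolution` appear to be two halves of the same refactoring. The per-allocation lower bound now comes from `TargetMachineFeatures::minimum_alignment_for_allocation`, and the Eigen paths are gated on that bound reaching `kEigenExpectedTensorAlignment`. The sketch below is a standalone approximation, not the code in this change. `MinimumAlignmentForAllocation`, `IsAlignedForEigen`, the explicit `pointer_size` parameter, and the value 16 for the expected alignment are all assumptions for illustration; the diff shows neither the body of `minimum_alignment_for_allocation` nor the value of the constant.

```cpp
#include <cstdint>

// Assumption: std::int64_t stands in for XLA's int64 typedef.
using int64 = std::int64_t;

// Assumed value, for illustration only; the diff does not show the constant.
constexpr int64 kAssumedEigenAlignment = 16;

// Mirrors the logic of the removed IrEmitter::MinimumAlignmentForBufferSize,
// with the pointer size passed in rather than read from the LLVM data layout.
int64 MinimumAlignmentForAllocation(int64 size_bytes, int64 pointer_size) {
  if (size_bytes == 0) {
    return 1;  // No need to align empty buffers.
  }
  const int64 kMallocAlignmentThreshold = 512;
  return size_bytes >= kMallocAlignmentThreshold ? 2 * pointer_size
                                                 : pointer_size;
}

// Hypothetical analogue of the is_aligned lambda: the allocation size is
// element count times element width, as in GetMinimumAlignmentForArray.
bool IsAlignedForEigen(int64 num_elements, int64 bytes_per_element,
                       int64 pointer_size) {
  const int64 allocation_size_bytes = num_elements * bytes_per_element;
  return MinimumAlignmentForAllocation(allocation_size_bytes, pointer_size) >=
         kAssumedEigenAlignment;
}
```

Under these assumptions, a 128-element F32 array (512 bytes) on a 64-bit target would pass the check, while an 8-element one (32 bytes) would not and would therefore be kept off the Eigen path.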
- if (PotentiallyImplementedAsEigenConvolution(*convolution)) { + if (PotentiallyImplementedAsEigenConvolution(*convolution, + target_machine_features_)) { const Shape& lhs_shape = lhs->shape(); const Shape& rhs_shape = rhs->shape(); const Shape& convolution_shape = convolution->shape(); @@ -1027,12 +1008,14 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // We will accumulate the products into this sum to calculate // the output entry at the given index. PrimitiveType lhs_element_type = lhs->shape().element_type(); + llvm::Type* lhs_llvm_type = + llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_); llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_), - "convolution_sum_address", &ir_builder_, + lhs_llvm_type, "convolution_sum_address", &ir_builder_, MinimumAlignmentForPrimitiveType(lhs_element_type)); - ir_builder_.CreateStore( - llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0), sum_address); + llvm::Value* constant_zero = + llvm::Constant::getNullValue(lhs_llvm_type); + ir_builder_.CreateStore(constant_zero, sum_address); llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &ir_builder_); std::vector kernel_spatial(num_spatial_dims); @@ -1186,7 +1169,13 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type, int64_type, int64_type, int64_type, int64_type}, /*isVarArg=*/false); - const char* fn_name = runtime::kEigenFftSymbolName; + + bool multi_threaded_eigen = + hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + const char* fn_name = multi_threaded_eigen + ? 
runtime::kEigenFftSymbolName + : runtime::kEigenSingleThreadedFftSymbolName; + llvm::Function* fft_func = llvm::cast( module_->getOrInsertFunction(fn_name, fft_type)); fft_func->setCallingConv(llvm::CallingConv::C); @@ -1208,16 +1197,45 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { } Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { - if (hlo_module_config_.replica_count() == 1) { - // When there is a single replica, a cross replica sum is the identity - // function, and the buffer assignment expects a copy (we could eliminate - // these at the HLO level as an optimization). - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs)); + if (hlo_module_config_.replica_count() != 1) { + // TODO(b/33011107): Support nontrivial cross replica sum on CPU. + return Unimplemented( + "CrossReplicaSum with >1 replica is not implemented on CPU."); + } + + // When there is a single replica, a cross replica sum is the identity + // function, and the buffer assignment expects a copy. + // + // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely + // in algebraic-simplifier, but currently on some platforms + // HloModuleConfig::num_replicas changes between when the module is compiled + // and when it's run. + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs)); + + // CRS with one operand and one replica is simply the identity function. + if (crs->operand_count() == 1) { return EmitMemcpy(*crs->operand(0), *crs); } - // TODO(b/33011107): Support cross replica sum on CPU. - return Unimplemented("CrossReplicaSum is not implemented on CPU."); + // CRS with multiple operands and one replica produces a (one-deep) tuple. 
+ std::vector operand_ptrs; + for (int64 i = 0; i < crs->operand_count(); ++i) { + llvm::Value* in_ptr = GetEmittedValueFor(crs->operand(i)); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice, + assignment_.GetUniqueSlice(crs, {i})); + + const Shape& operand_shape = crs->operand(i)->shape(); + CHECK(ShapeUtil::IsArray(operand_shape)) + << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); + operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape)); + + // TODO(b/63762267): Be more aggressive about specifying alignment. + ir_builder_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, + /*SrcAlign=*/1, + ShapeUtil::ByteSizeOf(operand_shape)); + } + llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &ir_builder_, module_); + return Status::OK(); } // Fills up the free variables in 'index_with_free_var' with values from diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 5a040760804fa5609e1d68511d4b2abe8e2ec8f9..32c536e18fee86cc60067ba3b25ab1eb0e4233df 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -76,7 +76,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - llvm::TargetMachine* target_machine, + const TargetMachineFeatures* target_machine, ExternalConstantPool* external_constant_pool); ~IrEmitter() override; @@ -514,9 +514,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Calculate the alignment of a buffer allocated for a given primitive type. int MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type); - // Calculate the alignment of a buffer with a particular size. - int MinimumAlignmentForBufferSize(int64 buffer_size); - // Returns the number of bytes within the shape. 
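Review note on the `HandleCrossReplicaSum` rewrite above: with a single replica there is nothing to sum across, so each operand is copied unchanged into its output slice, and with several operands the copies are collected into a one-deep tuple. The sketch below models only that data movement in plain C++. `SingleReplicaCrossReplicaSum` and the use of `std::vector<float>` for buffers are hypothetical stand-ins; the real code emits LLVM IR `memcpy` calls and, for a single operand, copies directly rather than building a tuple.

```cpp
#include <cstring>
#include <utility>
#include <vector>

// Hypothetical stand-in for the single-replica path: every operand buffer is
// copied byte-for-byte into its own output buffer, and the output buffers
// together play the role of the one-deep result tuple.
std::vector<std::vector<float>> SingleReplicaCrossReplicaSum(
    const std::vector<std::vector<float>>& operands) {
  std::vector<std::vector<float>> tuple_elements;
  tuple_elements.reserve(operands.size());
  for (const std::vector<float>& operand : operands) {
    std::vector<float> out(operand.size());
    if (!operand.empty()) {
      // Plain copy: the "sum" over one replica is the identity function.
      std::memcpy(out.data(), operand.data(), operand.size() * sizeof(float));
    }
    tuple_elements.push_back(std::move(out));
  }
  return tuple_elements;
}
```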
int64 ByteSizeOf(const Shape& shape) const; @@ -530,13 +527,14 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Value* program_buffer_address); - llvm::GlobalVariable* EmitGlobalForLiteral(const Literal& literal); + // Returns a ConstExpr bitcast. + llvm::Constant* EmitGlobalForLiteral(const Literal& literal); const HloModuleConfig& hlo_module_config_; bool is_top_level_computation_; - TargetMachineFeatures target_machine_features_; + const TargetMachineFeatures& target_machine_features_; int64 external_global_constant_counter_ = 0; ExternalConstantPool* external_constant_pool_; @@ -551,7 +549,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { } }; - tensorflow::gtl::FlatMap emitted_literals_; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 47e8405ff2ea2c8aa59c66cffb2705d4ab4a6752..4fa5984b0466b178a587e97cbced97deac749f74 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -38,7 +38,7 @@ class SimpleCostModel : public ParallelCostModel { const int64 min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size. // Return target parallel task count in [1, max_parallelism_]. return std::min(max_parallelism_, - std::max(1LL, instruction_cost / min_cost_per_thread)); + std::max(int64{1}, instruction_cost / min_cost_per_thread)); } private: @@ -63,7 +63,7 @@ class DefaultCostModel : public ParallelCostModel { int64 max_parallelism; // Calculate flops-to-bytes-ratio for 'instruction'. 
const int64 bytes_accessed = - std::max(1LL, cost_analysis_->bytes_accessed(*instruction)); + std::max(int64{1}, cost_analysis_->bytes_accessed(*instruction)); const float flops_to_bytes_ratio = cost_analysis_->flop_count(*instruction) / static_cast(bytes_accessed); @@ -93,7 +93,7 @@ class DefaultCostModel : public ParallelCostModel { } // Return target parallel task count in [1, max_parallelism_]. return std::min(max_parallelism, - std::max(1LL, instruction_cost / min_cost_per_thread)); + std::max(int64{1}, instruction_cost / min_cost_per_thread)); } private: @@ -104,7 +104,9 @@ class DefaultCostModel : public ParallelCostModel { ParallelTaskAssignment::ParallelTaskAssignment( const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module) { + const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module, + const TargetMachineFeatures* target_machine_features) + : target_machine_features_(*target_machine_features) { VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism; // Run cost analysis on 'module'. 
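Review note on the `std::max(1LL, ...)` to `std::max(int64{1}, ...)` changes above: `std::max` deduces a single template parameter from both arguments, so the literal must have exactly the same type as the other operand. The diff does not state the motivation; one plausible reading is portability to platforms where the 64-bit integer typedef is `long` rather than `long long`, in which case `1LL` would not match and the call would fail to compile. The sketch below is a standalone illustration. It assumes `std::int64_t` as a stand-in for XLA's `int64`, and `TargetTaskCount` is a hypothetical helper that mirrors the clamp in the cost models, not the actual implementation.

```cpp
#include <algorithm>
#include <cstdint>

// Assumption: std::int64_t stands in for XLA's int64 typedef.
using int64 = std::int64_t;

// Hypothetical helper mirroring the clamp used by the cost models: returns a
// parallel task count in [1, max_parallelism].
int64 TargetTaskCount(int64 instruction_cost, int64 min_cost_per_thread,
                      int64 max_parallelism) {
  // int64{1} has the same type as the quotient no matter how int64 is
  // defined, so std::max always deduces one consistent template argument.
  return std::min(max_parallelism,
                  std::max(int64{1}, instruction_cost / min_cost_per_thread));
}
```

Writing the constant as `int64{1}` also avoids an explicit `std::max<int64>(...)` at each call site.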
auto cost_analysis = MakeUnique(shape_size); @@ -139,8 +141,10 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng || (opcode == HloOpcode::kConvolution && - PotentiallyImplementedAsEigenConvolution(*instruction)) || - PotentiallyImplementedAsEigenDot(*instruction) || + PotentiallyImplementedAsEigenConvolution(*instruction, + target_machine_features_)) || + PotentiallyImplementedAsEigenDot(*instruction, + target_machine_features_) || (opcode == HloOpcode::kFusion && instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || ShapeUtil::IsTuple(instruction->shape())) { @@ -231,7 +235,8 @@ bool ParallelTaskAssigner::AssignParallelTasksHelper( void ParallelTaskAssigner::ComputeTargetParallelTasks( HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) { ParallelTaskAssignment parallel_task_assignment(max_parallelism_, - shape_size_function_, module); + shape_size_function_, module, + &target_machine_features_); // Compute parallel task counts for all instructions in 'module'. for (auto* computation : module->computations()) { diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index 7140dabe516cd7ea9260456e994e8b63b68c60d6..8becc8fa23424d7454cc783eb9d853aecb5d053b 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -39,7 +40,8 @@ class ParallelTaskAssignment { // 'module': the containing HloModule. ParallelTaskAssignment(const int64 max_parallelism, const HloCostAnalysis::ShapeSizeFunction& shape_size, - HloModule* module); + HloModule* module, + const TargetMachineFeatures* target_machine_features); ~ParallelTaskAssignment() {} // Computes and returns the target parallel task count for 'instruction'. @@ -47,6 +49,7 @@ class ParallelTaskAssignment { private: std::unique_ptr cost_model_; + const TargetMachineFeatures& target_machine_features_; }; // ParallelTaskAssigner computes target parallel task counts for all HLOs @@ -63,8 +66,11 @@ class ParallelTaskAssigner : public HloPassInterface { // 'shape_size': shape size function used by HloCostAnalysis during parallel // task assignment. 
ParallelTaskAssigner(const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size) - : max_parallelism_(max_parallelism), shape_size_function_(shape_size) {} + const HloCostAnalysis::ShapeSizeFunction& shape_size, + const TargetMachineFeatures* target_machine_features) + : max_parallelism_(max_parallelism), + shape_size_function_(shape_size), + target_machine_features_(*target_machine_features) {} ~ParallelTaskAssigner() override {} tensorflow::StringPiece name() const override { @@ -94,6 +100,7 @@ class ParallelTaskAssigner : public HloPassInterface { int64 max_parallelism_; HloCostAnalysis::ShapeSizeFunction shape_size_function_; + const TargetMachineFeatures& target_machine_features_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index 13eb75a57213b1a68a5732a4f6061efdf97fa4f4..fc2efbaf9a22b02cd729da2f367d53bc15506836 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -31,6 +32,19 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase { // Use any value larger than 2 since we only test whether a module is // parallelized or not const int max_parallelism_ = 10; + + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_; + + ParallelTaskAssignmentTest() + : target_machine_features_([](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }) {} + + StatusOr<bool> RunParallelTaskAssigner(HloModule* module) { + return cpu::ParallelTaskAssigner(max_parallelism_, shape_size_func_, + &target_machine_features_) + .Run(module); + } }; TEST_F(ParallelTaskAssignmentTest, DotOperationNotParallelized) { @@ -45,9 +59,7 @@ TEST_F(ParallelTaskAssignmentTest, DotOperationNotParallelized) { )"; ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( - max_parallelism_, shape_size_func_) - .Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); EXPECT_FALSE(changed); } @@ -74,9 +86,7 @@ TEST_F(ParallelTaskAssignmentTest, )"; ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( - max_parallelism_, shape_size_func_) - .Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); EXPECT_FALSE(changed); } @@ -92,9 +102,7 @@ TEST_F(ParallelTaskAssignmentTest, RngOperationNotParallelized) { )"; ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( - max_parallelism_, shape_size_func_) - .Run(&module())); + 
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); EXPECT_FALSE(changed); } @@ -108,9 +116,7 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { )"; ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( - max_parallelism_, shape_size_func_) - .Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); EXPECT_FALSE(changed); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h index 984cb0616e02475babad7160d0f43bb23de0b50e..0bf693edd0b985a4e62c16414646cc6a17db26ee 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h @@ -21,8 +21,6 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/numeric_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/types.h" // 'tensorflow' namespace is used so that int64 and other types don't require @@ -71,11 +69,9 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, in_dims[0] = input_batch; Eigen::DSizes out_dims; out_dims[0] = input_batch; - TensorShape temp_shape{input_batch}; for (int i = 0; i < FFTRank; i++) { in_dims[i + 1] = fft_shape[i]; out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; - temp_shape.AddDim(fft_shape[i]); } const Eigen::TensorMap, Eigen::Aligned> @@ -88,8 +84,8 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank); // Compute the full FFT using a temporary tensor. 
- Tensor temp(DataTypeToEnum::v(), temp_shape); - auto full_fft = temp.flat_inner_dims(); + Eigen::Tensor full_fft(in_dims); + const Eigen::DSizes zero_start_indices; full_fft.device(device) = input.template fft(axes); @@ -112,11 +108,9 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, in_dims[0] = input_batch; Eigen::DSizes out_dims; out_dims[0] = input_batch; - TensorShape temp_shape{input_batch}; for (int i = 0; i < FFTRank; i++) { in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; out_dims[i + 1] = fft_shape[i]; - temp_shape.AddDim(fft_shape[i]); } const Eigen::TensorMap, Eigen::Aligned> @@ -129,8 +123,7 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, // region we will slice from input given fft_shape. We slice input to // fft_shape on its inner-most dimensions, except the last (which we // slice to fft_shape[-1] / 2 + 1). - Tensor temp(DataTypeToEnum::v(), temp_shape); - auto full_fft = temp.flat_inner_dims(); + Eigen::Tensor full_fft(out_dims); // Calculate the starting point and range of the source of // negative frequency part. 
@@ -179,7 +172,6 @@ template void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, int32 fft_type, int64 input_batch, int64 fft_length0, int64 fft_length1, int64 fft_length2) { - CHECK(::xla::FftType_IsValid(fft_type)) << fft_type; switch (fft_type) { case ::xla::FftType::FFT: EigenFftC2C( @@ -204,7 +196,8 @@ void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, input_batch, fft_length0, fft_length1, fft_length2); break; default: - LOG(FATAL) << "Unsupported FFT type: " << fft_type; + // Unsupported FFT type + abort(); } } @@ -230,7 +223,8 @@ void EigenFftImpl(const EigenDevice& device, void* out, void* operand, fft_length1, fft_length2); break; default: - LOG(FATAL) << "Unsupported FFT rank " << fft_rank; + // Unsupported FFT rank + abort(); } } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc index 92da5f71c23d5e1450b39ea8b7bb8345f6fabb3b..f8c8dd5e93d53db8d87be0208b5cf4daac3464f1 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifdef INTEL_MKL +#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML) #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "third_party/intel_mkl_ml/include/mkl_cblas.h" #include "third_party/intel_mkl_ml/include/mkl_service.h" diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc new file mode 100644 index 0000000000000000000000000000000000000000..2613ddb12704aea7d0884c6c8c062dc028383639 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" + +#include "tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int32; +using tensorflow::int64; + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedFft( + const void* run_options_ptr, void* out, void* operand, int32 fft_type, + int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + tensorflow::xla::EigenFftImpl(Eigen::DefaultDevice(), out, operand, fft_type, + fft_rank, input_batch, fft_length0, fft_length1, + fft_length2); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h new file mode 100644 index 0000000000000000000000000000000000000000..dcd133d012cf074a4cd2f550585881388bea6156 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ + +#include "tensorflow/core/platform/types.h" + +extern "C" { + +extern void __xla_cpu_runtime_EigenSingleThreadedFft( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out, + void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank, + tensorflow::int64 input_batch, tensorflow::int64 fft_length0, + tensorflow::int64 fft_length1, tensorflow::int64 fft_length2); + +} // extern "C" + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.cc b/tensorflow/compiler/xla/service/cpu/shape_partition.cc index 42fe955f1917e0268dc739e44fbd0a7afb39185c..d12c5396148d32adb178b955a34e050cc56784da 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition.cc @@ -115,7 +115,7 @@ ShapePartitionIterator::ShapePartitionIterator( for (int i = 0; i < dimension_partition_sizes_.size(); ++i) { const int64 dim_size = shape_.dimensions(dimensions_[i]); dimension_partition_sizes_[i] = - std::max(1LL, dim_size / dimension_partition_counts_[i]); + std::max(int64{1}, dim_size / dimension_partition_counts_[i]); } // Calculate the partition strides for each dimension. diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index ff6f0a9d4e443c2ed7d2dd6c58f4aaf28205b0cb..c4c90515ac7ec2721cb9ea48d42e3c5080e249af 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -38,6 +38,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h" #include "tensorflow/compiler/xla/types.h" @@ -73,23 +74,33 @@ llvm::StringRef GetHostCpuName() { } } // namespace +/*static*/ std::unique_ptr<llvm::TargetMachine> +SimpleOrcJIT::InferTargetMachineForJIT( + const llvm::TargetOptions& target_options, + llvm::CodeGenOpt::Level opt_level) { + std::unique_ptr<llvm::TargetMachine> target_machine( + llvm::EngineBuilder() + .setTargetOptions(target_options) + .setOptLevel(opt_level) + .selectTarget( + /*TargetTriple=*/llvm::Triple(), /*MArch=*/"", + /*MCPU=*/GetHostCpuName(), + /*MAttrs=*/DetectMachineAttributes())); + CHECK(target_machine != nullptr); + return target_machine; +} + SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, bool enable_fast_math, bool disable_expensive_passes, LLVMCompiler::ModuleHook pre_optimization_hook, LLVMCompiler::ModuleHook post_optimization_hook) - : target_machine_( - CHECK_NOTNULL(llvm::EngineBuilder() - .setTargetOptions(target_options) - .setOptLevel(opt_level) - .selectTarget( - /*TargetTriple=*/llvm::Triple(), /*MArch=*/"", - /*MCPU=*/GetHostCpuName(), - /*MAttrs=*/DetectMachineAttributes()))), + : target_machine_(InferTargetMachineForJIT(target_options, opt_level)), disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), symbol_resolver_(llvm::orc::createLegacyLookupResolver( + execution_session_, [this](const std::string& name) -> llvm::JITSymbol { return this->ResolveRuntimeSymbol(name); }, @@ -192,6 +203,7 @@ bool RegisterKnownJITSymbols() { 
REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedFft); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index f4260a95bc45557b6cd969f7d3fff01c8b392575..1851a3ee0bb97b4860605d7211a6ae70ac88686b 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -95,6 +95,12 @@ class SimpleOrcJIT { return &external_constant_pool_; } + // Creates an llvm::TargetMachine suitable for JITting code that will run on + // the current machine. + static std::unique_ptr<llvm::TargetMachine> InferTargetMachineForJIT( + const llvm::TargetOptions& target_options, + llvm::CodeGenOpt::Level opt_level); + private: llvm::JITSymbol ResolveRuntimeSymbol(const std::string& name); diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc index eeb049737dddd11ef2ce229df772baec3ac03dd8..a0cd8ee2d2be10bcee9c2e216e24908d949e2d7b 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc @@ -18,7 +18,7 @@ limitations under the License. 
namespace xla { namespace cpu { -llvm::TargetTransformInfo* TargetMachineFeatures::GetTargetTransformInfoFor( +llvm::TargetTransformInfo* LLVMTargetMachineFeatures::GetTargetTransformInfoFor( const llvm::Function& function) const { auto it = target_transform_info_cache_.find(&function); if (it == target_transform_info_cache_.end()) { @@ -31,5 +31,30 @@ llvm::TargetTransformInfo* TargetMachineFeatures::GetTargetTransformInfoFor( return &it->second; } +int64 LLVMTargetMachineFeatures::minimum_alignment_for_allocation( + int64 size_bytes) const { + // GLibc malloc returns a pointer with alignment 8 on 32-bit platforms and 16 + // on 64-bit platforms. TCMalloc returns a pointer with alignment 8 for + // allocations smaller than kMallocAlignmentThreshold bytes and at least + // alignment 16 for allocations greater than or equal to + // kMallocAlignmentThreshold bytes. N.B. We could improve on this lower bound + // by explicitly allocating the memory with posix_memalign. This is + // complicated by our desire to allow parameter buffers created by clients to + // be consumed directly by the JIT. + if (size_bytes == 0) { + // No need to align empty buffers. + return 1; + } + + const int64 kMallocAlignmentThreshold = 512; + + int pointer_size = target_machine_->getPointerSize(0); + int buffer_alignment = + size_bytes >= kMallocAlignmentThreshold ? 2 * pointer_size : pointer_size; + DCHECK_GT(buffer_alignment, 0); + + return buffer_alignment; +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.h b/tensorflow/compiler/xla/service/cpu/target_machine_features.h index 703942615e552dccde7ddec8c8b90e8a486652af..8b00ae9e47eeed26ffe80707b89593b267e8dbb8 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.h +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.h @@ -24,43 +24,68 @@ limitations under the License. 
namespace xla { namespace cpu { -// Wraps an llvm::TargetMachine and parses out some information that feeds into -// LLVM IR code generation decisions. +// Abstract interface for classes providing information about the target we're +// compiling for. class TargetMachineFeatures { public: static constexpr int kX86AvxVectorByteSize = 32; - TargetMachineFeatures(llvm::TargetMachine* target_machine) - : target_machine_(target_machine) {} + // Input and output tensor buffers must be aligned to this many bytes if we + // want to call an Eigen backed GEMM or Convolution. + static constexpr int kEigenExpectedTensorAlignment = 16; // Return the vectorization factor, which is the number of bytes of data // explicitly vectorized routines will try to process at once. - int vectorization_factor_in_bytes() const { - // Ideally this should be a function of the cache line size (which we can - // get from llvm::TargetTransformInfo::getCacheLineSize) of the target - // machine. Guess a value of 128 bytes for now. - return 128; - } + virtual int vectorization_factor_in_bytes() const = 0; // Return the size of the largest vector size in bytes. We need to pass in // "function" since llvm functions can contain annotations for specializing // them to specific micro-architectures (though currently XLA does not use // this functionality). - int vector_register_byte_size(const llvm::Function& function) const { - llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function); - return tti->getRegisterBitWidth(/*Vector=*/true) / 8; - } + virtual int vector_register_byte_size( + const llvm::Function& function) const = 0; // Return the number of elements of type `type` that can fit into the largest // vector register available. We need to pass in "function" since llvm // functions can contain annotations for specializing them to specific // micro-architectures (though currently XLA does not use this functionality). 
+ virtual int vector_register_num_elements(const llvm::Function& function, + PrimitiveType type) const = 0; + + // Returns the minimum alignment for a buffer of size size_bytes. + virtual int64 minimum_alignment_for_allocation(int64 size_bytes) const = 0; + + virtual ~TargetMachineFeatures() = default; +}; + +// Implements the TargetMachineFeatures interface using an llvm::TargetMachine. +class LLVMTargetMachineFeatures : public TargetMachineFeatures { + public: + static constexpr int kX86AvxVectorByteSize = 32; + + LLVMTargetMachineFeatures(llvm::TargetMachine* target_machine) + : target_machine_(target_machine) {} + + int vectorization_factor_in_bytes() const override { + // Ideally this should be a function of the cache line size (which we can + // get from llvm::TargetTransformInfo::getCacheLineSize) of the target + // machine. Guess a value of 128 bytes for now. + return 128; + } + + int vector_register_byte_size(const llvm::Function& function) const override { + llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function); + return tti->getRegisterBitWidth(/*Vector=*/true) / 8; + } + int vector_register_num_elements(const llvm::Function& function, - PrimitiveType type) const { + PrimitiveType type) const override { return vector_register_byte_size(function) / (primitive_util::BitWidth(type) / 8); } + int64 minimum_alignment_for_allocation(int64 size_bytes) const override; + private: llvm::TargetTransformInfo* GetTargetTransformInfoFor( const llvm::Function& function) const; diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h b/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h new file mode 100644 index 0000000000000000000000000000000000000000..ffc6927cbe1a2b6fd1a1ca3aac9b6e047741c2af --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_FAKE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_FAKE_H_ + +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" + +namespace xla { +namespace cpu { +// Delegates calls to minimum_alignment_for_allocation to a user provided +// std::function, crashes on all other methods. +// +// Primarily useful for testing. 
+class TargetMachineFeaturesWithFakeAlignmentLogic + : public TargetMachineFeatures { + public: + explicit TargetMachineFeaturesWithFakeAlignmentLogic( + std::function<int64(int64)> fake_alignment_logic) + : fake_alignment_logic_(std::move(fake_alignment_logic)) {} + + int vectorization_factor_in_bytes() const override { + LOG(FATAL) << "Unexpected call to " << __func__; + } + + int vector_register_byte_size(const llvm::Function& function) const override { + LOG(FATAL) << "Unexpected call to " << __func__; + } + + int vector_register_num_elements(const llvm::Function& function, + PrimitiveType type) const override { + LOG(FATAL) << "Unexpected call to " << __func__; + } + + int64 minimum_alignment_for_allocation(int64 size_bytes) const override { + return fake_alignment_logic_(size_bytes); + } + + private: + std::function<int64(int64)> fake_alignment_logic_; +}; +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_FAKE_H_ diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 18a915e5339623c73fee0e339fe75ee405898a36..66ae5ef0f66e90982102d73e474f5d0582f5415c 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -32,7 +32,6 @@ cc_library( deps = [ "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", - "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) @@ -153,9 +152,9 @@ tf_cc_test( srcs = ["cpu_literal_caching_test.cc"], deps = [ "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -167,9 +166,9 @@ tf_cc_test( srcs = ["cpu_outfeed_test.cc"], deps = [ 
"//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc index ed8f375bd6186e4805fe9ded5be9ae7c9f4d5c84..faac927027c48e44eb8ff1fcc4109fbc177fc579 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -64,8 +64,8 @@ TEST_F(CpuExternalConstantsTest, BasicNegative) { // The constant array in this test case is small enough that there is no need // to externalize it. TestWithArray(/*rows=*/4, /*cols=*/4, R"( -CHECK-NOT: @constant_global_0 = external constant [4 x [4 x float]], align 8 -CHECK: @0 = private constant [4 x [4 x float]] {{.*}}, align 8 +CHECK-NOT: @constant_global_0 = external constant [16 x float], align 8 +CHECK: @0 = private constant [16 x float] {{.*}}, align 8 )"); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index d6e0425c5542be89835571f0103b1829f63cc2c2..27044b1d62027e3b83744c486cb790269e505aff 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -15,7 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" namespace xla { namespace cpu { @@ -55,12 +55,12 @@ ENTRY main { )"; string filecheck_pattern = R"( -CHECK: private constant [2 x [3 x [2 x float]]] -CHECK-NOT: private constant [2 x [3 x [2 x float]]] +CHECK: private constant [12 x float] +CHECK-NOT: private constant [12 x float] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); CpuAotCompilationOptions options{ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", @@ -78,34 +78,34 @@ TEST_F(CpuDuplicateConstantsTest, RepeatedTupleConstants) { HloModule RepeatedConstants while_body { - arg_body = (f32[2,1]{1,0}, f32[2]{0}) parameter(0) - ROOT const = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) + arg_body = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) + ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) } while_cond { - arg_cond = (f32[2,1]{1,0}, f32[2]{0}) parameter(0) + arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) ROOT unknown = pred[] infeed() } ENTRY main { param = f32[2,3,2] parameter(0) - const_a = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) - const_b = (f32[2,1]{1,0}, f32[2]{0}) while((f32[2,1]{1,0}, f32[2]{0}) const_a), condition=while_cond, body=while_body + const_a = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) + const_b = (f32[2,1]{1,0}, f32[1]{0}) while((f32[2,1]{1,0}, f32[1]{0}) const_a), condition=while_cond, body=while_body - out0 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_a) - ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_b) + out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) 
const_a) + ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b) } )"; string filecheck_pattern = R"( +CHECK: private constant [1 x float] CHECK: private constant [2 x float] -CHECK: private constant [2 x [1 x float]] +CHECK-NOT: private constant [1 x float] CHECK-NOT: private constant [2 x float] -CHECK-NOT: private constant [2 x [1 x float]] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); CpuAotCompilationOptions options{ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index 879372eb13884cdb7edd8cfb3e8b4bac4e314951..1ee279290b6fcfe775ce9867d424b1c031f5d2bd 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" namespace xla { namespace cpu { @@ -37,11 +37,11 @@ ENTRY main { )"; string filecheck_pattern = R"( -CHECK: private constant [2 x [3 x [2 x float]]] +CHECK: private constant [12 x float] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); CpuAotCompilationOptions options{ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index cd1165e23812861ba9951546b7dd744529232196..c444d151858d3a152a01b99657ffae89ebc6b487 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ 
-427,5 +427,27 @@ llvm::Value* LlvmVariable::Get() const { void LlvmVariable::Set(llvm::Value* new_value) { ir_builder_->CreateStore(new_value, alloca_); } + +TileVariable::TileVariable(VectorSupportLibrary* vector_support, + std::vector initial_value) { + for (llvm::Value* initial_vector_value : initial_value) { + storage_.emplace_back(vector_support, initial_vector_value); + } +} + +std::vector TileVariable::Get() const { + std::vector result; + c_transform(storage_, std::back_inserter(result), + [&](VectorVariable vect_var) { return vect_var.Get(); }); + return result; +} + +void TileVariable::Set(tensorflow::gtl::ArraySlice value) { + CHECK_EQ(value.size(), storage_.size()); + for (int64 i = 0, e = value.size(); i < e; i++) { + storage_[i].Set(value[i]); + } +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index 6479bf76aab581ae3ec2923d98dab53720cab203..49c2a4e2f4bae9e1672b7d2fe891301bce08bd4b 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace cpu { @@ -143,6 +144,12 @@ class VectorSupportLibrary { llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, llvm::Value* offset_elements); + llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, + llvm::Value* offset_elements, int64 scale) { + return ComputeOffsetPointer( + base_pointer, + ir_builder_->CreateMul(ir_builder_->getInt64(scale), offset_elements)); + } llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, int64 offset_elements) { return ComputeOffsetPointer(base_pointer, @@ -311,6 +318,21 @@ class ScalarVariable : public LlvmVariable { Set(initial_value); } }; + +// This wraps a set of alloca-backed stack variables that can, as a whole, store +// a tile. A "tile" is a sequence of vectors that is typically used as a 2D +// grid of scalar values (e.g. for tiled GEMMs). 
+class TileVariable { + public: + TileVariable(VectorSupportLibrary* vector_support, + std::vector initial_value); + + std::vector Get() const; + void Set(tensorflow::gtl::ArraySlice value); + + private: + std::vector storage_; +}; } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc index 35db4fd2a22cc1615ade77a801cb28c504db09a6..e228bb56bce8febcca28ae171f6de90973d020ab 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.cc +++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc @@ -29,7 +29,7 @@ StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator( : DeviceMemoryAllocator(platform), stream_executors_(stream_executors.begin(), stream_executors.end()) {} -StatusOr StreamExecutorMemoryAllocator::Allocate( +StatusOr StreamExecutorMemoryAllocator::Allocate( int device_ordinal, uint64 size, bool retry_on_failure) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * stream_executor, GetStreamExecutor(device_ordinal)); @@ -40,22 +40,17 @@ StatusOr StreamExecutorMemoryAllocator::Allocate( tensorflow::strings::HumanReadableNumBytes(size).c_str(), size, device_ordinal); } - return result; + return OwningDeviceMemory(result, device_ordinal, this); } -tensorflow::Status StreamExecutorMemoryAllocator::Deallocate( - int device_ordinal, se::DeviceMemoryBase* mem) { - if (!mem->is_null()) { +Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal, + se::DeviceMemoryBase mem) { + if (!mem.is_null()) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * stream_executor, GetStreamExecutor(device_ordinal)); - // We make a local copy of 'mem' so the original is not zeroed out by the - // Deallocate() call below. This gives us a better chance of - // catching double-free bugs, since Deallocate silently succeeds for null - // values. 
- se::DeviceMemoryBase mem_copy(*mem); - stream_executor->Deallocate(&mem_copy); + stream_executor->Deallocate(&mem); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr StreamExecutorMemoryAllocator::GetStreamExecutor( diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h index da45c4d45a1c56fd39b1e3e17ff131de59ceeced..d87b86caf0d3acaa5bf9a455cff2315cedb2496d 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.h +++ b/tensorflow/compiler/xla/service/device_memory_allocator.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/owning_device_memory.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -37,28 +38,29 @@ class DeviceMemoryAllocator { : platform_(platform) {} virtual ~DeviceMemoryAllocator() {} + // Allocates memory on the device. + // + // If size > 0 and the returned StatusOr is OK, the wrapped OwningDeviceMemory + // must not be null. If size == 0, must return a null OwningDeviceMemory. + // // 'retry_on_failure': If false, and the first attempt to allocate the memory // fails, the allocation should return immediately without retrying. An // example use case is optional scratch spaces where a failure has only // performance impact. - // - // Allocate() should return a null pointer for a size-0 allocation. - // Deallocate() must be a no-op for null pointers. - virtual StatusOr Allocate(int device_ordinal, - uint64 size, - bool retry_on_failure) = 0; + virtual StatusOr Allocate(int device_ordinal, uint64 size, + bool retry_on_failure) = 0; // Two-arg version of Allocate(), which sets retry-on-failure to true. // // (We don't simply use a default argument on the virtual Allocate function // because default args on virtual functions are disallowed by the Google // style guide.) 
- StatusOr Allocate(int device_ordinal, uint64 size) { + StatusOr Allocate(int device_ordinal, uint64 size) { return Allocate(device_ordinal, size, /*retry_on_failure=*/true); } - virtual tensorflow::Status Deallocate(int device_ordinal, - se::DeviceMemoryBase* mem) = 0; + // Must be a nop for null pointers. + virtual Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) = 0; // Return the platform that the allocator allocates memory on. const se::Platform* platform() const { return platform_; } @@ -68,6 +70,7 @@ class DeviceMemoryAllocator { virtual bool AllowsAsynchronousDeallocation() const = 0; protected: + friend class OwningDeviceMemory; const se::Platform* platform_; }; @@ -79,14 +82,13 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator { const se::Platform* platform, tensorflow::gtl::ArraySlice stream_executors); - StatusOr Allocate(int device_ordinal, uint64 size, - bool retry_on_failure) override; + StatusOr Allocate(int device_ordinal, uint64 size, + bool retry_on_failure) override; // Pull in two-arg overload that sets retry_on_failure to true. 
using DeviceMemoryAllocator::Allocate; - tensorflow::Status Deallocate(int device_ordinal, - se::DeviceMemoryBase* mem) override; + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; bool AllowsAsynchronousDeallocation() const override; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 0528b076027603796a445d8b0e9cbcebd1b513a7..ee2b455730f8f520db6652f0352f8a96291cac73 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -138,6 +138,9 @@ class DfsHloVisitorBase { virtual Status HandleExp(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } + virtual Status HandleExpm1(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleFloor(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } @@ -150,6 +153,9 @@ class DfsHloVisitorBase { virtual Status HandleClz(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } + virtual Status HandleLog1p(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleCos(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } @@ -191,6 +197,10 @@ class DfsHloVisitorBase { return HandleElementwiseUnary(hlo); } + virtual Status HandleDomain(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0; @@ -233,6 +243,8 @@ class DfsHloVisitorBase { virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0; + virtual Status HandleGenerateToken(HloInstructionPtr token) = 0; + // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". 
virtual Status FinishVisit(HloInstructionPtr root) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 240faebe62f5cee4f61b3c36b5e8f653cfd6db8e..6934e00a4b665e9e6a4302e0c0a8ce1d5bb94373 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -188,6 +188,9 @@ class DfsHloVisitorWithDefaultBase Status HandleGather(HloInstructionPtr gather) override { return DefaultAction(gather); } + Status HandleGenerateToken(HloInstructionPtr token) override { + return DefaultAction(token); + } // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index ae32d33766093cf4e610a0dc05f7d8c88cb37d31..93fea7ead7a86bb34c449668fd88a58145681eb1 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -418,8 +418,12 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } case HloOpcode::kExp: return EmitExp(op->shape().element_type(), operand_value); + case HloOpcode::kExpm1: + return EmitExpm1(op->shape().element_type(), operand_value); case HloOpcode::kLog: return EmitLog(op->shape().element_type(), operand_value); + case HloOpcode::kLog1p: + return EmitLog1p(op->shape().element_type(), operand_value); case HloOpcode::kCos: return EmitCos(op->shape().element_type(), operand_value); case HloOpcode::kSin: @@ -452,17 +456,15 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( llvm::ConstantFP::get(type, 1.0))); } case HloOpcode::kIsFinite: { - // (x == x) && abs(x) != inf + // abs(x) o!= inf, this works because the comparison returns false if + // either operand is NaN. 
auto type = operand_value->getType(); - auto equal_self = - ir_builder_->CreateFCmpOEQ(operand_value, operand_value); auto abs_value = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::fabs, {operand_value}, {type}, ir_builder_); auto infinity = llvm::ConstantFP::getInfinity(type); auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity); - auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite); return ir_builder_->CreateZExt( - result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + not_infinite, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } case HloOpcode::kNegate: return ir_builder_->CreateFNeg(operand_value); @@ -493,6 +495,22 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( return EmitComposeComplex( op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); } + case HloOpcode::kLog1p: { + // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + llvm::Type* llvm_ty = a->getType(); + auto one = llvm::ConstantFP::get(llvm_ty, 1.0); + auto a_plus_one = ir_builder_->CreateFAdd(a, one); + auto sum_sq = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(a_plus_one, a_plus_one), + ir_builder_->CreateFMul(b, b)); + TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); + TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one)); + auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); + return EmitComposeComplex( + op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); + } case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); TF_RET_CHECK(primitive_util::IsComplexType(from_type)); @@ -523,6 +541,20 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), ir_builder_->CreateFMul(exp_a, sin_b)); } + case HloOpcode::kExpm1: { + // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i + TF_ASSIGN_OR_RETURN( + auto exp_a, 
+ EmitExp(component_type, EmitExtractReal(operand_value))); + TF_ASSIGN_OR_RETURN( + auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value))); + TF_ASSIGN_OR_RETURN( + auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); + auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0); + auto real_result = + ir_builder_->CreateFSub(ir_builder_->CreateFMul(exp_a, cos_b), one); + auto imag_result = ir_builder_->CreateFMul(exp_a, sin_b); + return EmitComposeComplex(op, real_result, imag_result); + } case HloOpcode::kCos: { // cos(z) = .5(e^(iz) + e^(-iz)) // cos(a+bi) = .5(e^(-b+ai) + e^(b-ai)) @@ -975,6 +1007,28 @@ StatusOr ElementalIrEmitter::EmitLog(PrimitiveType prim_type, {value->getType()}, ir_builder_); } +StatusOr ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, + llvm::Value* value) const { + auto x = value; + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto one = llvm::ConstantFP::get(type, 1.0); + auto negative_half = llvm::ConstantFP::get(type, -0.5); + // When x is large, the naive evaluation of ln(x + 1) is more + // accurate than the Taylor series. + TF_ASSIGN_OR_RETURN(auto for_large_x, + EmitLog(prim_type, ir_builder_->CreateFAdd(x, one))); + // The Taylor series for ln(x+1) is x - x^2/2 + x^3/3 - ….
+ auto for_small_x = ir_builder_->CreateFMul( + ir_builder_->CreateFAdd(ir_builder_->CreateFMul(negative_half, x), one), + x); + const auto kAntilogarithmIsSmallThreshold = 1e-4; + auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, + {type}, ir_builder_); + auto x_is_small = ir_builder_->CreateFCmpOLT( + abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold)); + return ir_builder_->CreateSelect(x_is_small, for_small_x, for_large_x); +} + StatusOr ElementalIrEmitter::EmitSin(PrimitiveType prim_type, llvm::Value* value) const { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, @@ -993,6 +1047,29 @@ StatusOr ElementalIrEmitter::EmitExp(PrimitiveType prim_type, {value->getType()}, ir_builder_); } +StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) const { + auto x = value; + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto one = llvm::ConstantFP::get(type, 1.0); + auto half = llvm::ConstantFP::get(type, 0.5); + // When the exponent is large, the naive evaluation of e^(x) - 1 is more + // accurate than the Taylor series. + TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value)); + auto for_large_x = ir_builder_->CreateFSub(exp_x, one); + // The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + …. + // We want exp(x)-1 which is x + x^2/2 + x^3/6 + …. 
+ auto x_squared = ir_builder_->CreateFMul(x, x); + auto x_squared_over_two = ir_builder_->CreateFMul(x_squared, half); + auto for_small_x = ir_builder_->CreateFAdd(x, x_squared_over_two); + const auto kExponentIsSmallThreshold = 1e-5; + auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, + {type}, ir_builder_); + auto x_is_small = ir_builder_->CreateFCmpOLT( + abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold)); + return ir_builder_->CreateSelect(x_is_small, for_small_x, for_large_x); +} + StatusOr ElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { @@ -1468,6 +1545,26 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, operand_to_generator.at(hlo->operand(1))(dim_index)); + + // Clamp the start index so that the sliced portion fits in the operand: + // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to officially document different behavior.
+ start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value, + index[i]->getType()); + llvm::Value* operand_dim_size = llvm::ConstantInt::get( + start_index_value->getType(), input_hlo->shape().dimensions(i)); + llvm::Value* output_dim_size = llvm::ConstantInt::get( + start_index_value->getType(), hlo->shape().dimensions(i)); + + start_index_value = EmitIntegralMin( + ir_builder_->CreateSub(operand_dim_size, output_dim_size), + EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0), + start_index_value, /*is_signed=*/true), + /*is_signed=*/true); + start_index_value->setName( AsStringRef(IrName(hlo, StrCat("start_idx", i)))); slice_start_index[i] = start_index_value; @@ -1476,14 +1573,8 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( llvm_ir::IrArray::Index input_index(rank); for (int64 i = 0; i < rank; ++i) { // Emit IR which computes: - // input_index = (start_index + offset_index) % dim_size - // Security note: this is the code that keeps the indices in-bounds. - llvm::Value* dim_size = llvm::ConstantInt::get( - index[i]->getType(), input_hlo->shape().dimensions(i)); - llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast( - slice_start_index[i], index[i]->getType()); - input_index[i] = ir_builder_->CreateURem( - ir_builder_->CreateAdd(start_index, index[i]), dim_size); + // input_index = start_index + offset_index + input_index[i] = ir_builder_->CreateAdd(slice_start_index[i], index[i]); } return operand_to_generator.at(input_hlo)(input_index); } @@ -1582,104 +1673,48 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( const int64 rank = ShapeUtil::Rank(input_hlo->shape()); llvm_ir::IrArray::Index slice_start_index(rank); llvm_ir::IrArray::Index slice_limit_index(rank); - // Slice starts at update[index - slice_start_index_adjusted], - // where adjusted value = slice_start_index when in bounds, and - // adjusted value = slice_start_index - input_dim, when wrapping. 
- llvm_ir::IrArray::Index slice_start_index_adjusted(rank); - // Slice intersection gathers (ANDs) conditions on all ranks for which // 'input' is set to 'update' llvm::Value* slice_intersection = ir_builder_->getTrue(); for (int64 i = 0; i < rank; ++i) { - // Emit IR to read dynamic start indices from 'start_hlo'. llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, operand_to_generator.at(start_hlo)(dim_index)); - start_index_value->setName( - AsStringRef(IrName(hlo, StrCat("start_idx", i)))); - slice_start_index[i] = ir_builder_->CreateZExtOrBitCast( - start_index_value, index[i]->getType()); + // Clamp the start index so that the update region fits in the operand. + // start_index = clamp(start_index, 0, input_dim_size - update_dim_size) + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to officially document different behavior. + start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value, + index[i]->getType()); llvm::Value* input_dim_size = llvm::ConstantInt::get( index[i]->getType(), input_hlo->shape().dimensions(i)); llvm::Value* update_dim_size = llvm::ConstantInt::get( index[i]->getType(), update_hlo->shape().dimensions(i)); - // Generate code to handle wrapping semantics: - // slice_start_index[i] = slice_start_index[i] % input_dim_size; - // slice_limit_index[i] = slice_start_index[i] + update_dim_size. - // slice_start_index[i] is updated in place and it will now be in - // range. slice_limit_index[i] may be out of range, and it's being - // URem-ed below if so.
- slice_start_index[i] = - ir_builder_->CreateURem(slice_start_index[i], input_dim_size); + start_index_value = EmitIntegralMin( + ir_builder_->CreateSub(input_dim_size, update_dim_size), + EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0), + start_index_value, /*is_signed=*/true), + /*is_signed=*/true); + + start_index_value->setName( + AsStringRef(IrName(hlo, StrCat("start_idx", i)))); + slice_start_index[i] = start_index_value; slice_limit_index[i] = ir_builder_->CreateAdd(slice_start_index[i], update_dim_size); - // Test if slice_limit_index[i] is in bounds - llvm::Value* in_bounds = - ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size); - llvm_ir::LlvmIfData if_in_bounds = - llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); - - // Handle true BB (slice_limit_index[i] <= input_dim_size). - SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_); - // Check that index[i] >= slice_start_index[i] && - // index[i] < slice_limit_index[i] - llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd( + slice_intersection = ir_builder_->CreateAnd( slice_intersection, ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), - "slice_intersection_in"); - slice_intersection_in_bounds = ir_builder_->CreateAnd( - slice_intersection_in_bounds, + "slice_intersection"); + slice_intersection = ir_builder_->CreateAnd( + slice_intersection, ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]), - "slice_intersection_in"); - - // Handle false BB (slice_limit_index[i] > input_dim_size). - SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_); - // Check that index[i] >= slice_start_index[i] || - // index[i] < slice_limit_index[i]%input_dim_size. 
- llvm::Value* index_wraps = ir_builder_->CreateICmpSLT( - index[i], - ir_builder_->CreateURem(slice_limit_index[i], input_dim_size)); - llvm::Value* slice_intersection_or = ir_builder_->CreateOr( - ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), index_wraps, - "slice_intersection_out"); - llvm::Value* slice_intersection_out_of_bounds = ir_builder_->CreateAnd( - slice_intersection, slice_intersection_or, "slice_intersection_out"); - // Create value for slice_start_index_adjusted[i] when out of bounds. - // If within out-of-bounds if. - llvm_ir::LlvmIfData if_start_needs_adjustment = - llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_); - SetToFirstInsertPoint(if_start_needs_adjustment.true_block, ir_builder_); - llvm::Value* slice_start_index_adjusted_oob = - ir_builder_->CreateSub(slice_start_index[i], input_dim_size); - SetToFirstInsertPoint(if_start_needs_adjustment.after_block, ir_builder_); - llvm::PHINode* slice_start_index_adjusted_phi = - ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(), 2); - slice_start_index_adjusted_phi->addIncoming( - slice_start_index_adjusted_oob, if_start_needs_adjustment.true_block); - slice_start_index_adjusted_phi->addIncoming( - slice_start_index[i], if_start_needs_adjustment.false_block); - // End of if within if. - - // After checking in/out of bounds. 
- SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_); - llvm::PHINode* phi_slice_intersection = - ir_builder_->CreatePHI(slice_intersection->getType(), 2); - phi_slice_intersection->addIncoming(slice_intersection_in_bounds, - if_in_bounds.true_block); - phi_slice_intersection->addIncoming(slice_intersection_out_of_bounds, - if_start_needs_adjustment.after_block); - slice_intersection = phi_slice_intersection; - - llvm::PHINode* phi_index = - ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2); - phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block); - phi_index->addIncoming(slice_start_index_adjusted_phi, - if_start_needs_adjustment.after_block); - slice_start_index_adjusted[i] = phi_index; + "slice_intersection"); } // Emit: @@ -1696,12 +1731,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // Compute update index for intersection case. llvm_ir::IrArray::Index update_index(rank); for (int64 i = 0; i < rank; ++i) { - llvm::Value* update_dim_size = llvm::ConstantInt::get( - index[i]->getType(), update_hlo->shape().dimensions(i)); - // NOTE: Subtraction will be positive due to bounds checking above. 
- update_index[i] = ir_builder_->CreateURem( - ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]), - update_dim_size); + update_index[i] = ir_builder_->CreateSub(index[i], slice_start_index[i]); } TF_ASSIGN_OR_RETURN(llvm::Value * true_value, operand_to_generator.at(update_hlo)(update_index)); @@ -1784,8 +1814,13 @@ StatusOr ElementalIrEmitter::EmitElementalDot( const llvm_ir::IrArray::Index& dot_result_index) const { auto lhs_generator = operand_to_generator.at(hlo->operand(0)); auto rhs_generator = operand_to_generator.at(hlo->operand(1)); - int64 contracted_dim_size = hlo->operand(0)->shape().dimensions( - hlo->operand(0)->shape().dimensions_size() - 1); + + const DotDimensionNumbers& dim_numbers = hlo->dot_dimension_numbers(); + int64 lhs_contracting_dim = dim_numbers.lhs_contracting_dimensions(0); + int64 rhs_contracting_dim = dim_numbers.rhs_contracting_dimensions(0); + + int64 contracted_dim_size = + hlo->operand(0)->shape().dimensions(lhs_contracting_dim); int64 lhs_dims = hlo->operand(0)->shape().dimensions_size(); int64 rhs_dims = hlo->operand(1)->shape().dimensions_size(); @@ -1816,13 +1851,12 @@ StatusOr ElementalIrEmitter::EmitElementalDot( for (int64 i = 0; i < lhs_dims - 1; i++) { lhs_index.push_back(dot_result_index[i]); } - lhs_index.push_back(inner_loop->GetIndVarValue()); + lhs_index.InsertAt(lhs_contracting_dim, inner_loop->GetIndVarValue()); - for (int64 i = 0; i < rhs_dims - 2; i++) { + for (int64 i = 0; i < rhs_dims - 1; i++) { rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]); } - rhs_index.push_back(inner_loop->GetIndVarValue()); - rhs_index.push_back(dot_result_index.back()); + rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue()); llvm::Value* current_accumulator = ir_builder_->CreateLoad(accumulator_alloca); @@ -1877,10 +1911,12 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kExpm1: case 
HloOpcode::kFloor: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kReal: diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 26dff0d96f1d0f00fcdf12410a3769d18a186773..d199473374ad394913413a7d3fe805f8782936f7 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -105,6 +105,9 @@ class ElementalIrEmitter { virtual StatusOr EmitLog(PrimitiveType prim_type, llvm::Value* value) const; + virtual StatusOr EmitLog1p(PrimitiveType prim_type, + llvm::Value* value) const; + virtual StatusOr EmitSin(PrimitiveType prim_type, llvm::Value* value) const; @@ -114,6 +117,9 @@ class ElementalIrEmitter { virtual StatusOr EmitExp(PrimitiveType prim_type, llvm::Value* value) const; + virtual StatusOr EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) const; + virtual StatusOr EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const; diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8980d4303353a132ada2b3c685b4f2856c33c6a1 --- /dev/null +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +using tensorflow::gtl::nullopt; + +class ElementalIrEmitterExecutionTest : public HloTestBase { + protected: + void RunTest(const string& hlo_text, + tensorflow::gtl::ArraySlice args) { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text, config)); + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), args, nullopt)); + } +}; + +XLA_TEST_F(ElementalIrEmitterExecutionTest, DotFusion) { + const string hlo_text = R"( +HloModule FusedDot + +fused_computation { + arg0 = s32[1,2,1]{2,1,0} parameter(0) + reshape.lhs = s32[2,1]{1,0} reshape(arg0) + arg1 = s32[1,2,1]{2,1,0} parameter(1) + reshape.rhs = s32[2,1]{1,0} reshape(arg1) + ROOT dot = s32[1,1]{1,0} dot(reshape.lhs, reshape.rhs), lhs_contracting_dims={0}, rhs_contracting_dims={0} +} + +ENTRY main { + entry_arg0 = s32[1,2,1]{2,1,0} parameter(0) + entry_arg1 = s32[1,2,1]{2,1,0} parameter(1) + ROOT fusion = s32[1,1]{1,0} fusion(entry_arg0, entry_arg1), kind=kLoop, calls=fused_computation +} +)"; + + std::unique_ptr lhs = Literal::CreateR3({{{1}, {2}}}); + std::unique_ptr rhs = Literal::CreateR3({{{3}, {4}}}); + RunTest(hlo_text, {lhs.get(), rhs.get()}); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 
8119478ce934da06969024905e5e054e0b509b03..6df172db8e541c5cef7aab04f3d8611fc735e8b0 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -129,20 +129,6 @@ StatusOr Executable::ExecuteOnStreamWrapper( return return_value; } -Status Executable::DumpSessionModule() { - TF_RET_CHECK(dumping()); - const string& directory_path = - module_config().debug_options().xla_dump_executions_to(); - VersionedComputationHandle versioned_handle = entry_computation_handle(); - // This filename does not include the version number because the computation - // is only ever executed at one version. - string filename = tensorflow::strings::Printf( - "computation_%lld__%s__execution_%lld", versioned_handle.handle.handle(), - session_module_->entry().name().c_str(), ++execution_count_); - return Executable::DumpToDirectory(directory_path, filename, - *session_module_); -} - Status Executable::DumpHloSnapshot() { TF_RET_CHECK(dumping_snapshot()); TF_RET_CHECK(hlo_snapshot_->has_hlo() && @@ -156,26 +142,6 @@ Status Executable::DumpHloSnapshot() { return Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot_); } -/* static */ Status Executable::DumpToDirectory( - const string& directory_path, string filename, - const SessionModule& session_module) { - tensorflow::Env* env = tensorflow::Env::Default(); - if (!env->IsDirectory(directory_path).ok()) { - // NB! CreateDir does not work reliably with multiple XLA threads -- two - // threads can race to observe the absence of the dump directory and - // simultaneously try to create it, causing the "losing" thread to get a - // "directory already exists" error. 
- TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory_path)); - } - filename = SanitizeFileName(std::move(filename)); - string file_path = tensorflow::io::JoinPath(directory_path, filename); - string result; - TF_RET_CHECK( - tensorflow::SerializeToStringDeterministic(session_module, &result)); - return tensorflow::WriteStringToFile(tensorflow::Env::Default(), file_path, - result); -} - /* static */ Status Executable::DumpToDirectory( const string& directory_path, string filename, const HloSnapshot& hlo_session) { diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 4f0466c544738fa1ec4602ee5104daee8d969c83..dc1f26ea65cc707d4f0522af2aa3ec40621632f1 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -27,9 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -132,26 +130,12 @@ class Executable { const HloModuleConfig& module_config() const { return hlo_module_->config(); } - // Returns the versioned computation handle of the computation computed by - // this executable. - const VersionedComputationHandle& entry_computation_handle() const { - return hlo_module_->entry_computation_handle(); - } - // The shape (including layout) that results from this execution. This is the // shape of the DeviceMemoryBase result value in ExecuteOnStream above. 
const Shape& host_result_shape() const { return hlo_module_->config().host_entry_computation_layout().result_shape(); } - // TODO(b/74197823): Delete the session module dumping helpers. - void set_session_module(std::unique_ptr session_module) { - session_module_ = std::move(session_module); - } - bool dumping() const { return session_module_ != nullptr; } - SessionModule* session_module() const { return session_module_.get(); } - Status DumpSessionModule(); - // Dumping helpers. void set_hlo_snapshot(std::unique_ptr hlo_snapshot) { hlo_snapshot_ = std::move(hlo_snapshot); @@ -160,10 +144,6 @@ class Executable { HloSnapshot* hlo_snapshot() const { return hlo_snapshot_.get(); } Status DumpHloSnapshot(); - // Dump session_module to directory_path/filename. - static Status DumpToDirectory(const string& directory_path, string filename, - const SessionModule& session_module); - // Dump hlo snapshot to directory_path/filename. static Status DumpToDirectory(const string& directory_path, string filename, const HloSnapshot& hlo_session); @@ -179,9 +159,6 @@ class Executable { // around. const std::unique_ptr hlo_module_; - // SessionModule this was compiled from. Null if not dumping executions. - std::unique_ptr session_module_; - // HloSnapshot this was compiled from. Null if not dumping executions. 
std::unique_ptr hlo_snapshot_; diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index 2f0b9ed2bd98fbea4e67c0a30d5aa41ff6a06979..6794cfe297b0fb9a15eb9b7e6906d225f9597d07 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -37,11 +37,11 @@ AsyncExecution::AsyncExecution(Backend* backend, } } -tensorflow::Status AsyncExecution::BlockUntilDone() const { +Status AsyncExecution::BlockUntilDone() const { for (auto& stream : streams_) { TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); } - return tensorflow::Status::OK(); + return Status::OK(); } ExecutionTracker::ExecutionTracker() : next_handle_(1) {} @@ -61,7 +61,7 @@ ExecutionHandle ExecutionTracker::Register( return execution_handle; } -tensorflow::Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { +Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { tensorflow::mutex_lock lock(execution_mutex_); auto it = handle_to_execution_.find(handle.handle()); if (it == handle_to_execution_.end()) { @@ -69,7 +69,7 @@ tensorflow::Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { handle.handle()); } handle_to_execution_.erase(handle.handle()); - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr ExecutionTracker::Resolve( diff --git a/tensorflow/compiler/xla/service/execution_tracker.h b/tensorflow/compiler/xla/service/execution_tracker.h index 5b6bddf9f16a85f7863f4d05c39c7d4c99209af1..4458152dd9a98890fc3a3e7f324245ec68821467 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.h +++ b/tensorflow/compiler/xla/service/execution_tracker.h @@ -43,7 +43,7 @@ class AsyncExecution { AsyncExecution(Backend* backend, std::vector streams, const ExecutionProfile& profile, GlobalDataHandle result); - tensorflow::Status BlockUntilDone() const; + Status BlockUntilDone() const; const GlobalDataHandle& result() const { 
return result_; } @@ -77,7 +77,7 @@ class ExecutionTracker { GlobalDataHandle data); // Unregisters the execution for the given handle. - tensorflow::Status Unregister(const ExecutionHandle& handle); + Status Unregister(const ExecutionHandle& handle); // Resolves the given ExecutionHandle to an AsyncExecution. Returns an // error status if the given handle is not found, which means that the diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/service/g3doc/hlo_parser.md similarity index 100% rename from tensorflow/compiler/xla/tools/parser/README.md rename to tensorflow/compiler/xla/service/g3doc/hlo_parser.md diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc index 1c72ca066502eb549bf8638cdf0b7827b06f92d7..020ffcd106862cb2641a9f3bceb70acdd969a458 100644 --- a/tensorflow/compiler/xla/service/gather_expander_test.cc +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -14,9 +14,9 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/gather_expander.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -36,7 +36,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); Status status = GatherExpander{}.Run(module.get()).status(); EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); @@ -63,7 +63,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get())); ASSERT_TRUE(changed); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index ddb687314ee8221ba9282f230db498b3a5d23499..5ee67ccb4ae147683c7b41941670c6fc413a0d09 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -89,7 +89,7 @@ GenericTransferManager::TransferLiteralFromDevice( } Status GenericTransferManager::TransferLiteralToDevice( - se::StreamExecutor* executor, const Literal& literal, + se::StreamExecutor* executor, const LiteralSlice& literal, const ShapedBuffer& device_buffer) { const Shape& shape = literal.shape(); VLOG(2) << "transferring literal shape to device: " @@ -115,7 +115,7 @@ Status GenericTransferManager::TransferLiteralToDevice( TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == device_memory.size()); // Element is array-shaped: transfer array data to device buffer.
- const auto subliteral = LiteralView::Create(literal, index); + const auto subliteral = LiteralSlice(literal, index); std::unique_ptr relayed_out_literal; const void* source; if (LayoutUtil::Equal(device_subshape.layout(), @@ -137,7 +137,7 @@ Status GenericTransferManager::TransferLiteralToDevice( } Status GenericTransferManager::TransferLiteralToInfeed( - se::StreamExecutor* executor, const Literal& literal) { + se::StreamExecutor* executor, const LiteralSlice& literal) { return Unimplemented("Generic transfer to Infeed"); } diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 0579099de40ba3e43f69e4e6474b56691064c692..3da9570ef7eebcdf618439f628fb4d5589993e4f 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -45,11 +45,11 @@ class GenericTransferManager : public TransferManager { se::StreamExecutor* executor, const ShapedBuffer& device_buffer) override; Status TransferLiteralToDevice(se::StreamExecutor* executor, - const Literal& literal, + const LiteralSlice& literal, const ShapedBuffer& device_buffer) override; Status TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) override; + const LiteralSlice& literal) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, Literal* literal) override; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 7cb7f550730eeb562a6163cf5499ffaaf02d3327..5e02631a5856a45762027d246144d6f6dd913c26 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1,6 +1,8 @@ # Description: # GPU-specific components in XLA service implementation. 
+load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") + licenses(["notice"]) # Apache 2.0 package(default_visibility = [":friends"]) @@ -23,6 +25,11 @@ filegroup( load("//tensorflow:tensorflow.bzl", "tf_cc_test") +xla_proto_library( + name = "backend_configs", + srcs = ["backend_configs.proto"], +) + cc_library( name = "gpu_constants", srcs = ["gpu_constants.cc"], @@ -133,6 +140,7 @@ cc_library( "ir_emitter_unnested.h", ], deps = [ + ":backend_configs", ":cudnn_convolution_runner", ":elemental_ir_emitter", ":gpu_constants", @@ -156,6 +164,7 @@ cc_library( "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", @@ -266,6 +275,7 @@ cc_library( "while_thunk.h", ], deps = [ + ":backend_configs", ":buffer_allocations", ":cudnn_convolution_runner", ":infeed_manager", @@ -291,6 +301,7 @@ cc_library( "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/platform/default/build_config:cublas_plugin", "//tensorflow/core/platform/default/build_config:cudnn_plugin", @@ -321,6 +332,7 @@ cc_library( srcs = ["cudnn_convolution_algorithm_picker.cc"], hdrs = ["cudnn_convolution_algorithm_picker.h"], deps = [ + ":backend_configs", ":cudnn_convolution_runner", ":gpu_executable", ":ir_emission_utils", @@ -337,6 +349,7 @@ cc_library( srcs = ["cudnn_convolution_runner.cc"], hdrs = ["cudnn_convolution_runner.h"], deps = [ + ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", 
"//tensorflow/compiler/xla:status_macros", @@ -388,8 +401,10 @@ cc_library( deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:instruction_fusion", + "//tensorflow/compiler/xla/service:pattern_matcher", ], ) @@ -398,10 +413,41 @@ tf_cc_test( srcs = ["instruction_fusion_test.cc"], deps = [ ":instruction_fusion", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + +cc_library( + name = "multi_output_fusion", + srcs = ["multi_output_fusion.cc"], + hdrs = ["multi_output_fusion.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:multi_output_fusion", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "multi_output_fusion_test", + srcs = ["multi_output_fusion_test.cc"], + deps = [ + ":multi_output_fusion", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", ], ) @@ -443,9 +489,9 @@ tf_cc_test( ":instruction_fusion", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - 
"//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -505,6 +551,7 @@ cc_library( ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", + ":multi_output_fusion", ":pad_insertion", ":partition_assignment", ":stream_assignment", @@ -539,6 +586,8 @@ cc_library( "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/compiler/xla/service:while_loop_constant_sinking", + "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", "//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter", @@ -584,14 +633,18 @@ cc_library( srcs = ["gpu_layout_assignment.cc"], hdrs = ["gpu_layout_assignment.h"], deps = [ + ":gpu_options", ":ir_emission_utils", + ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -688,6 +741,27 @@ cc_library( ], ) +cc_library( + name = "gpu_options", + srcs = ["gpu_options.cc"], + hdrs = ["gpu_options.h"], + deps = [ + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( + name = "stream_executor_util", + srcs = ["stream_executor_util.cc"], + hdrs = ["stream_executor_util.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + tf_cc_test( name = "gpu_hlo_support_checker_test", srcs = ["gpu_hlo_support_checker_test.cc"], diff --git 
a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto new file mode 100644 index 0000000000000000000000000000000000000000..640c6392b8b820c708b853c2a3cea4d4116e85a8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto @@ -0,0 +1,27 @@ +syntax = "proto3"; + +package xla.gpu; + +// Backend configs for XLA:GPU. +// +// These are metadata that the GPU backend attaches to HloInstructions and later +// uses during e.g. codegen. +// +// Remember that proto3 doesn't give clients a way to tell the difference +// between a field not being present and a field having the default value. +// Choose your defaults carefully. +// +// No guarantee is made about the stability of these protos. +// +// See HloInstruction::backend_config() for more info. + +// Backend config for a convolution that runs through cudnn. +message CudnnConvBackendConfig { + // Opaque algorithm number of cudnn algorithm chosen for this conv. + int64 algorithm = 1; + + // Whether we may use tensor cores when running this conv. Even if this is + // true, cudnn may choose not to use tensor cores, e.g. because the GPU or + // selected algorithm doesn't support it.
+ bool tensor_ops_enabled = 2; +} diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index 837f05244f7a8c71483cc30eeac9e1c219e6bbd2..ab5149dcdb09290cd0c0b2233029d0988a95f036 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -37,11 +37,11 @@ void BufferAllocations::Builder::RegisterBuffer(BufferAllocation::Index index, } StatusOr> BufferAllocations::Builder::Build( - const BufferAssignment& buffer_assignment, int device_ordinal, + const BufferAssignment* buffer_assignment, int device_ordinal, DeviceMemoryAllocator* memory_allocator) { - const int64 num_buffers = buffer_assignment.Allocations().size(); - auto buffer_allocations = WrapUnique( - new BufferAllocations(num_buffers, device_ordinal, memory_allocator)); + const int64 num_buffers = buffer_assignment->Allocations().size(); + auto buffer_allocations = WrapUnique(new BufferAllocations( + num_buffers, device_ordinal, memory_allocator, buffer_assignment)); for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { // If buffer #i's address is already registered (e.g. external arguments or @@ -62,28 +62,28 @@ StatusOr> BufferAllocations::Builder::Build( // Allocate each allocation that might escape, or is the temp buffer. 
bool seen_temp_buffer = false; - const BufferAllocation& allocation = buffer_assignment.GetAllocation(i); + const BufferAllocation& allocation = buffer_assignment->GetAllocation(i); if (allocation.maybe_live_out() || allocation.IsPreallocatedTempBuffer()) { const int64 buffer_size = allocation.size(); se::DeviceMemoryBase buffer_address; if (buffer_size > 0) { - TF_ASSIGN_OR_RETURN(buffer_address, memory_allocator->Allocate( - device_ordinal, buffer_size)); - if (buffer_address == nullptr) { - return ResourceExhausted( - "Out of memory when allocating %s for buffer %lld.", - tensorflow::strings::HumanReadableNumBytes(buffer_size).c_str(), - i); - } - if (reinterpret_cast(buffer_address.opaque()) % + OwningDeviceMemory buffer; + TF_ASSIGN_OR_RETURN( + buffer, memory_allocator->Allocate(device_ordinal, buffer_size)); + if (reinterpret_cast(buffer.opaque()) % kCudaMallocAlignBytes != 0) { return InternalError( "Address returned by memory_allocator->Allocate must be a " "multiple of %llx, but was %p", - kCudaMallocAlignBytes, buffer_address.opaque()); + kCudaMallocAlignBytes, buffer.opaque()); } + // We do manual memory management within BufferAllocations. Be sure not + // to do a TF_RETURN_IF_ERROR between this line and the + // buffer_allocations->SetBuffer(buffer_address) call below! + buffer_address = buffer.Forget(); } + buffer_allocations->SetBuffer(i, buffer_address); if (allocation.IsPreallocatedTempBuffer()) { if (seen_temp_buffer) { @@ -103,28 +103,42 @@ StatusOr> BufferAllocations::Builder::Build( << "B)"; } } - return std::move(buffer_allocations); } -tensorflow::Status BufferAllocations::TearDown( - const std::set& live_addresses, - const BufferAssignment& buffer_assignment) { - // Deallocate temporary buffers. 
- const int64 num_buffers = buffer_assignment.Allocations().size(); +BufferAllocations::~BufferAllocations() { + if (!torn_down_) { + // Presumably if we're executing this branch, the caller is in an error + // state, otherwise it would have explicitly called TearDown so it could + // save some set of live addresses. So ignoring any errors in TearDown is + // sensible. + TearDown(/*live_addresses=*/{}).IgnoreError(); + } +} + +Status BufferAllocations::TearDown( + const std::set& live_addresses) { + // Deallocate temporary buffers, taking care to try to deallocate all of them + // even if one of the deallocations fails. + Status status; + const int64 num_buffers = buffer_assignment_->Allocations().size(); for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { - const BufferAllocation& allocation = buffer_assignment.GetAllocation(i); + const BufferAllocation& allocation = buffer_assignment_->GetAllocation(i); se::DeviceMemoryBase buffer_address = GetDeviceAddress(allocation.index()); // Deallocate buffers marked "maybe_live_out" but aren't actually live out, // and temp buffers. 
if ((allocation.maybe_live_out() && !live_addresses.count(buffer_address)) || allocation.IsPreallocatedTempBuffer()) { - TF_RETURN_IF_ERROR( - memory_allocator_->Deallocate(device_ordinal_, &buffer_address)); + auto dealloc_result = + memory_allocator_->Deallocate(device_ordinal_, buffer_address); + if (!dealloc_result.ok() && status.ok()) { + status = dealloc_result; + } } } - return tensorflow::Status::OK(); + torn_down_ = true; + return status; } se::DeviceMemoryBase BufferAllocations::GetDeviceAddress( diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h index c2fc35be4ca4bc6df85ee21fb6b564bfd6de3ec0..636623502597b3a66523938ba430e9d5a82f796c 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h @@ -48,13 +48,15 @@ class BufferAllocations { // `device_ordinal` is the number of the device this function allocates // memory on. StatusOr> Build( - const BufferAssignment& buffer_assignment, int device_ordinal, + const BufferAssignment* buffer_assignment, int device_ordinal, DeviceMemoryAllocator* memory_allocator); private: std::map registered_buffers_; }; + ~BufferAllocations(); + BufferAllocations(const BufferAllocations&) = delete; BufferAllocations& operator=(const BufferAllocations&) = delete; @@ -76,16 +78,16 @@ class BufferAllocations { // Tears down all buffers allocated by this object that are not in // `live_addresses`. 
- tensorflow::Status TearDown( - const std::set& live_addresses, - const BufferAssignment& buffer_assignment); + Status TearDown(const std::set& live_addresses); private: BufferAllocations(BufferAllocation::Index buffer_count, int device_ordinal, - DeviceMemoryAllocator* memory_allocator) + DeviceMemoryAllocator* memory_allocator, + const BufferAssignment* buffer_assignment) : buffers_(buffer_count), device_ordinal_(device_ordinal), - memory_allocator_(memory_allocator) {} + memory_allocator_(memory_allocator), + buffer_assignment_(buffer_assignment) {} // Sets the device address of buffer `buffer_index`. void SetBuffer(BufferAllocation::Index buffer_index, @@ -100,8 +102,9 @@ class BufferAllocations { se::DeviceMemoryBase temp_buffer_base_; int device_ordinal_; - DeviceMemoryAllocator* memory_allocator_; + const BufferAssignment* buffer_assignment_; + bool torn_down_ = false; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index dce8de2e301ecfaa4674b8be48b8c02bfabf3f4b..77a48965e031349b045a956fd3f28c58607328e5 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -35,9 +35,10 @@ ConditionalThunk::ConditionalThunk( true_thunk_(std::move(true_thunk_sequence), hlo), false_thunk_(std::move(false_thunk_sequence), hlo) {} -Status ConditionalThunk::Initialize(const GpuExecutable& executable) { - TF_RETURN_IF_ERROR(true_thunk_.Initialize(executable)); - TF_RETURN_IF_ERROR(false_thunk_.Initialize(executable)); +Status ConditionalThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { + TF_RETURN_IF_ERROR(true_thunk_.Initialize(executable, executor)); + TF_RETURN_IF_ERROR(false_thunk_.Initialize(executable, executor)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h 
b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h index e40872688fdad24d24db5f7cebb3206c77652dce..ee03865d174469285a9e98b8a30fea90d997df37 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h @@ -47,7 +47,8 @@ class ConditionalThunk : public Thunk { ConditionalThunk(const ConditionalThunk&) = delete; ConditionalThunk& operator=(const ConditionalThunk&) = delete; - Status Initialize(const GpuExecutable& executable) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, se::Stream* stream) override; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 64d3b84b8c73d82800270aebcebf7f7a8fec5fe4..f0881124128c9b043392ffc4fa3aee2cd5b754c7 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -29,11 +29,6 @@ namespace xla { namespace gpu { using se::dnn::AlgorithmDesc; -using se::dnn::BatchDescriptor; -using se::dnn::ConvolutionDescriptor; -using se::dnn::DataLayout; -using se::dnn::FilterDescriptor; -using se::dnn::FilterLayout; ConvolutionThunk::ConvolutionThunk( CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer, diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc index bf912fbd14de5874062a79db28186ab233575233..ee38c0318a878c7bcdc02afdcd146bfb4498d9a2 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc @@ -29,12 +29,12 @@ HostToDeviceCopyThunk::HostToDeviceCopyThunk( destination_buffer_(destination_buffer), mem_size_(mem_size) {} -tensorflow::Status HostToDeviceCopyThunk::ExecuteOnStream( +Status HostToDeviceCopyThunk::ExecuteOnStream( const 
BufferAllocations& buffer_allocations, se::Stream* stream) { se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); stream->ThenMemcpy(&destination_data, source_address_, mem_size_); - return tensorflow::Status::OK(); + return Status::OK(); } DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( @@ -46,14 +46,14 @@ DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( destination_buffer_(destination_buffer), mem_size_(mem_size) {} -tensorflow::Status DeviceToDeviceCopyThunk::ExecuteOnStream( +Status DeviceToDeviceCopyThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); se::DeviceMemoryBase source_data = buffer_allocations.GetDeviceAddress(source_buffer_); stream->ThenMemcpy(&destination_data, source_data, mem_size_); - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.h b/tensorflow/compiler/xla/service/gpu/copy_thunk.h index 2e7eb5f3445bc9294fa455ef31fb816cdba4726c..8b128386f61636de9ac41e856a2b00c578e05735 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.h @@ -39,8 +39,8 @@ class HostToDeviceCopyThunk : public Thunk { HostToDeviceCopyThunk(const HostToDeviceCopyThunk&) = delete; HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const void* source_address_; @@ -62,8 +62,8 @@ class DeviceToDeviceCopyThunk : public Thunk { DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk&) = delete; DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete; - tensorflow::Status 
ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const BufferAllocation::Slice source_buffer_; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 41ee45f55fafcbb96265b97f31e26b75ab96675c..3dc98c4c93ea2b9b68dd3ee27794a39847f8756c 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/core/lib/gtl/optional.h" @@ -35,35 +36,22 @@ class ScratchAllocator : public se::ScratchAllocator { ScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator) : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} - ~ScratchAllocator() override; - int64 GetMemoryLimitInBytes(se::Stream* stream) override { return 1LL << 32; // 4GB. TODO(jlebar): Tune this? 
} int64 TotalAllocatedBytes() { return total_allocated_bytes_; } - se::port::StatusOr> AllocateBytes( - se::Stream* stream, int64 byte_size) override; + StatusOr> AllocateBytes(se::Stream* stream, + int64 byte_size) override; private: const int device_ordinal_; DeviceMemoryAllocator* memory_allocator_; - std::vector allocated_buffers_; + std::vector allocated_buffers_; int64 total_allocated_bytes_ = 0; }; -ScratchAllocator::~ScratchAllocator() { - for (auto& allocated_buffer : allocated_buffers_) { - if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) - .ok()) { - // The program can still continue with failed deallocation. - LOG(ERROR) << "Failed to deallocate the allocated buffer: " - << allocated_buffer.opaque(); - } - } -} - -se::port::StatusOr> ScratchAllocator::AllocateBytes( +StatusOr> ScratchAllocator::AllocateBytes( se::Stream* stream, int64 byte_size) { CHECK_GE(byte_size, 0) << "byte_size must be positive."; if (byte_size > GetMemoryLimitInBytes(stream)) { @@ -74,19 +62,14 @@ se::port::StatusOr> ScratchAllocator::AllocateBytes( byte_size, GetMemoryLimitInBytes(stream))); } - auto status_or_memory = - memory_allocator_->Allocate(device_ordinal_, byte_size, - /*retry_on_failure=*/false); - if (!status_or_memory.ok()) { - return se::port::Status(se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Failed to allocate %lld bytes on device %d.", - byte_size, device_ordinal_)); - } - se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); - allocated_buffers_.push_back(allocated_buffer); + TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer, + memory_allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false)); total_allocated_bytes_ += byte_size; - return se::DeviceMemory(allocated_buffer); + + se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase(); + allocated_buffers_.push_back(std::move(allocated_buffer)); + return se::DeviceMemory(buffer_addr); } // Determines 
whether we can safely perform a winograd non-fused convolution for @@ -334,21 +317,20 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( Shape new_call_shape = ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0), ShapeUtil::MakeShape(U8, {scratch_bytes})}); - HloInstruction* algorithm_hlo = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(algorithm))); - HloInstruction* tensor_ops_enabled_hlo = - computation->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(tensor_ops_enabled))); + + CudnnConvBackendConfig backend_config; + backend_config.set_algorithm(algorithm); + backend_config.set_tensor_ops_enabled(tensor_ops_enabled); HloInstruction* new_call = computation->AddInstruction(HloInstruction::CreateCustomCall( new_call_shape, - {instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo, - tensor_ops_enabled_hlo}, + {instr->mutable_operand(0), instr->mutable_operand(1)}, instr->custom_call_target())); new_call->set_window(instr->window()); new_call->set_convolution_dimension_numbers( instr->convolution_dimension_numbers()); + TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config)); // Repackage new_call so it has the same shape as the original call, namely // (conv_result, u8[0]). diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 10b4c3de89989c52cfea5273c3d5b0beef76abd2..0645fbb3ad39f1f1649caf45a6068b5a196c30b9 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -14,6 +14,8 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -113,8 +115,17 @@ Status RunCudnnConvolution( // cuDNN's convolution APIs support the BDYX layout for activations/output and // the OIYX layout for weights. + DataLayout input_dl; + FilterLayout filter_dl; + DataLayout output_dl; + + TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl), + XlaConvLayoutsToStreamExecutorLayouts( + dnums, input_shape.layout(), filter_shape.layout(), + output_shape.layout())); + BatchDescriptor input_descriptor(effective_num_dimensions); - input_descriptor.set_layout(DataLayout::kBatchDepthYX) + input_descriptor.set_layout(input_dl) .set_feature_map_count( input_shape.dimensions(dnums.input_feature_dimension())) .set_count(input_shape.dimensions(dnums.input_batch_dimension())); @@ -126,7 +137,7 @@ Status RunCudnnConvolution( } FilterDescriptor filter_descriptor(effective_num_dimensions); - filter_descriptor.set_layout(FilterLayout::kOutputInputYX) + filter_descriptor.set_layout(filter_dl) .set_input_feature_map_count( filter_shape.dimensions(dnums.kernel_input_feature_dimension())) .set_output_feature_map_count( @@ -149,7 +160,7 @@ Status RunCudnnConvolution( } BatchDescriptor output_descriptor(effective_num_dimensions); - output_descriptor.set_layout(DataLayout::kBatchDepthYX) + output_descriptor.set_layout(output_dl) .set_feature_map_count( output_shape.dimensions(dnums.output_feature_dimension())) .set_count(output_shape.dimensions(dnums.output_batch_dimension())); diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 
5af7a77ea858563fbea05af8efd54f96a74aee93..b812dd7d3fbb25f279e87f79b647e299f29073ea 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -53,11 +53,17 @@ using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; using tensorflow::strings::StrAppend; +namespace { // Returns whether operand is a floating-point literal with the given value. bool IsFPLiteralWithValue(const HloInstruction* operand, float value) { - return operand->opcode() == HloOpcode::kConstant && - operand->literal().IsAllFloat(value); + if (operand->opcode() == HloOpcode::kConstant && + operand->literal().IsAllFloat(value)) { + return true; + } + return operand->opcode() == HloOpcode::kBroadcast && + IsFPLiteralWithValue(operand->operand(0), value); } +} // namespace GpuElementalIrEmitter::GpuElementalIrEmitter( const HloModuleConfig& hlo_module_config, llvm::Module* module, @@ -227,6 +233,11 @@ StatusOr GpuElementalIrEmitter::EmitLog( return EmitLibdeviceMathCall("__nv_log", {value}, {prim_type}, prim_type); } +StatusOr GpuElementalIrEmitter::EmitLog1p( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_log1p", {value}, {prim_type}, prim_type); +} + StatusOr GpuElementalIrEmitter::EmitSin( PrimitiveType prim_type, llvm::Value* value) const { return EmitLibdeviceMathCall("__nv_sin", {value}, {prim_type}, prim_type); @@ -242,6 +253,11 @@ StatusOr GpuElementalIrEmitter::EmitExp( return EmitLibdeviceMathCall("__nv_exp", {value}, {prim_type}, prim_type); } +StatusOr GpuElementalIrEmitter::EmitExpm1( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_expm1", {value}, {prim_type}, prim_type); +} + StatusOr GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h 
index 77d4569b1e8e398005e8f517ff086a77aedd382d..91f4d960aa62fff3e0699ece37a8c74d7dcf2f59 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -64,6 +64,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitLog(PrimitiveType prim_type, llvm::Value* value) const override; + StatusOr EmitLog1p(PrimitiveType prim_type, + llvm::Value* value) const override; + StatusOr EmitSin(PrimitiveType prim_type, llvm::Value* value) const override; @@ -73,6 +76,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitExp(PrimitiveType prim_type, llvm::Value* value) const override; + StatusOr EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) const override; + StatusOr EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const override; diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index cc747addbd152eb82b0b2ef92b8653fc861f97be..e14ee6918bf148861ecccac99355fccf7ae93103 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -31,23 +31,12 @@ FftScratchAllocator::FftScratchAllocator( int device_ordinal, DeviceMemoryAllocator* memory_allocator) : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} -FftScratchAllocator::~FftScratchAllocator() { - for (auto& allocated_buffer : allocated_buffers_) { - if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) - .ok()) { - // The program can still continue with failed deallocation. - LOG(ERROR) << "Failed to deallocate the allocated buffer: " - << allocated_buffer.opaque(); - } - } -} - int64 FftScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) { constexpr int64 kFftScratchSize = 1LL << 32; // 4GB by default. 
return kFftScratchSize; } -se::port::StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes( +StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes( se::Stream* stream, int64 byte_size) { CHECK_GE(byte_size, 0) << "byte_size must be positive."; if (byte_size > GetMemoryLimitInBytes(stream)) { @@ -58,18 +47,14 @@ se::port::StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes( byte_size, GetMemoryLimitInBytes(stream))); } - auto status_or_memory = - memory_allocator_->Allocate(device_ordinal_, byte_size, - /*retry_on_failure=*/false); - if (!status_or_memory.ok()) { - return tensorflow::errors::ResourceExhausted( - "Failed to allocate %lld bytes on device %d.", byte_size, - device_ordinal_); - } - se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); - allocated_buffers_.push_back(allocated_buffer); + TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer, + memory_allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false)); total_allocated_bytes_ += byte_size; - return se::DeviceMemory<uint8>(allocated_buffer); + + se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase(); + allocated_buffers_.push_back(std::move(allocated_buffer)); + return se::DeviceMemory<uint8>(buffer_addr); } namespace { @@ -121,8 +106,8 @@ FftThunk::FftThunk(FftType fft_type, input_shape_(input_shape), output_shape_(output_shape) {} -tensorflow::Status FftThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { VLOG(3) << "FFT type: " << FftTypeToString(fft_type_); VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_); VLOG(3) << "Output shape: " @@ -222,7 +207,7 @@ tensorflow::Status FftThunk::ExecuteOnStream( LOG(FATAL) << "unsupported fft type"; } if (launch_ok) { - return tensorflow::Status::OK(); + return Status::OK(); } return InternalError("Unable to launch fft for thunk %p with type %s", this,
FftTypeToString(fft_type_).c_str()); diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h index 24b1dca99865fe21d0ca3af91e0d169f7b74a78a..b0a22564f3a09bb67a3c01723f6e37c604656d45 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -39,8 +39,6 @@ class FftScratchAllocator : public se::ScratchAllocator { FftScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator); - ~FftScratchAllocator() override; - int64 GetMemoryLimitInBytes(se::Stream* stream) override; int64 TotalAllocatedBytes() { return total_allocated_bytes_; } @@ -51,7 +49,7 @@ class FftScratchAllocator : public se::ScratchAllocator { private: const int device_ordinal_; DeviceMemoryAllocator* memory_allocator_; - std::vector<se::DeviceMemoryBase> allocated_buffers_; + std::vector<OwningDeviceMemory> allocated_buffers_; int64 total_allocated_bytes_ = 0; }; @@ -73,8 +71,8 @@ class FftThunk : public Thunk { FftThunk& operator=(const FftThunk&) = delete; // Cannot share fft_plan_ // Does the FFT for the thunk on "stream".
- tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const se::fft::Type fft_type_; diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index 6e6966df3987eef29b2122c3ef8f11b7cd0bfe14..b36539e0cb8d0a2f4758dd90acbdd8fc7181b8ca 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -30,19 +30,20 @@ ForThunk::ForThunk(const int64 loop_limit, body_thunk_sequence_( MakeUnique(std::move(*body_thunk_sequence), hlo)) {} -tensorflow::Status ForThunk::Initialize(const GpuExecutable& executable) { - TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable)); - return tensorflow::Status::OK(); +Status ForThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { + TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor)); + return Status::OK(); } -tensorflow::Status ForThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { for (int64 i = 0; i < loop_limit_; ++i) { // Invoke loop body thunk sequence. 
TF_RETURN_IF_ERROR( body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h index c78d1c50686297aea8235af928aba562697f49bc..41ddfe0ceb1d0516c1c64feca53212a925632209 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h @@ -36,9 +36,10 @@ class ForThunk : public Thunk { ForThunk(const ForThunk&) = delete; ForThunk& operator=(const ForThunk&) = delete; - tensorflow::Status Initialize(const GpuExecutable& executable) override; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const int64 loop_limit_; diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index 2217776c7d5a5f92c520d56222988f80401be9e4..b22bb1d39ba177ef42673c7a3755694b43c15d14 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -17,9 +17,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace gpu { @@ -40,7 +40,7 @@ class FusionMergerTest : public HloTestBase {}; // Tuple // TEST_F(FusionMergerTest, MergeSharedFusionInstruction) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule MergeSharedFusionInstruction comp.3 { @@ -104,7 +104,7 @@ ENTRY MergeSharedFusionInstruction.Computation0 { // // Fusion2 is not merged because it exceeds the threshold flops-to-bytes ratio. TEST_F(FusionMergerTest, FlopsToBytesRatioThresholdExceeded) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule FlopsToBytesRatioThresholdExceeded comp.2 { @@ -162,7 +162,7 @@ ENTRY FlopsToBytesRatioThresholdExceeded.Computation1 { // is merged into Fusion0 and Fusion1) would exceed the bytes transferred // threshold. TEST_F(FusionMergerTest, BytesTransferredThresholdExeceeded) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule BytesTransferredThresholdExeceeded comp.2 { @@ -210,7 +210,7 @@ ENTRY BytesTransferredThresholdExeceeded.Computation2 { // Fusion2 is reduced for this test which makes the merge operation into its // operand below the bytes transferred threshold. TEST_F(FusionMergerTest, BytesTransferredThresholdNotExeceeded) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule BytesTransferredThresholdNotExeceeded comp.2 { @@ -253,7 +253,7 @@ ENTRY BytesTransferredThresholdNotExeceeded.Computation2 { // Check that we're willing to merge f1_computation into f2_computation, even // though f2 is an input fusion node. 
TEST_F(FusionMergerTest, WillMergeIntoInputFusion) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule m f1_computation { diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index f996fe486d1fe691899bd69dcedf3e29a963ff42..79fca43d022816645b8a07b9e806fe9cc3745e7c 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -215,6 +215,25 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { } } +DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) { + if (hlo_instruction.opcode() == HloOpcode::kDot) { + return hlo_instruction.dot_dimension_numbers(); + } + CHECK_EQ(hlo_instruction.opcode(), HloOpcode::kFusion); + CHECK_EQ(hlo_instruction.fusion_kind(), HloInstruction::FusionKind::kOutput); + CHECK_EQ(hlo_instruction.fused_expression_root()->opcode(), + HloOpcode::kMultiply); + // Try to find the dot inside the output fusion node. 
+ const HloInstruction* dot = + hlo_instruction.fused_expression_root()->operand(0); + if (dot->opcode() != HloOpcode::kDot) { + dot = hlo_instruction.fused_expression_root()->operand(1); + } + CHECK_EQ(dot->opcode(), HloOpcode::kDot); + + return dot->dot_dimension_numbers(); +} + } // namespace GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, @@ -232,8 +251,8 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, output_shape_(output_shape), alpha_(alpha) {} -tensorflow::Status GemmThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { VLOG(2) << "Executing a GemmThunk"; se::DeviceMemoryBase lhs_data = @@ -281,8 +300,7 @@ tensorflow::Status GemmThunk::ExecuteOnStream( shape.dimensions(!is_row_major)); }; - const DotDimensionNumbers& dim_nums = - hlo_instruction()->dot_dimension_numbers(); + DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction()); const MatrixDescriptor lhs_descriptor = make_descriptor( lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == 0); @@ -350,7 +368,7 @@ tensorflow::Status GemmThunk::ExecuteOnStream( if (!launch_ok) { return InternalError("Unable to launch cuBLAS gemm on stream %p", stream); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index f42cbf9e9483b59f1f103b128b36263ccaf64ec5..7a4830d64e7caef5a1170cbdbf8ab373fdaf16e2 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -47,8 +47,8 @@ class GemmThunk : public Thunk { GemmThunk& operator=(const GemmThunk&) = delete; // Does the gemm operation for the thunk on "stream", which must be non-null. 
- tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; // Returns true if we'll perform autotuning if run on the given stream. If // so, we want the GPU to be quiescent during autotuning, so as not to diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 4fdc4c89618bc0f179b2332373cb2fd3cf637390..afefc740d707cd6fd01420e950eafd37abe80119 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -52,6 +52,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" @@ -73,6 +74,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" +#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -128,9 +131,8 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { } // Runs optimization passes on the given HLO module. 
-tensorflow::Status OptimizeHloModule(HloModule* hlo_module, - se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* device_allocator) { +Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) { { HloPassPipeline pipeline("optimization"); pipeline.AddInvariantChecker(); @@ -161,8 +163,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, - /*rewrite_grad_op=*/true, - /*use_fusion=*/false); + /*rewrite_grad_op=*/true); // Rewrite gather ops into smaller ones. pass.AddPass(); @@ -175,6 +176,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); pass.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -201,18 +203,28 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, pipeline.AddInvariantChecker(); pipeline.AddPass(); pipeline.AddPass(); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + + { + HloPassPipeline pipeline("layout_assignment"); + pipeline.AddPass( + hlo_module->mutable_device_entry_computation_layout(), stream_exec); + + // The LayoutAssignment pass may leave behind kCopy instructions which are + // duplicate or NOPs, so remove them with algebraic simplification and CSE. + pipeline.AddPass>( + /*is_layout_sensitive=*/true, + /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { + return true; + }); // Choose the fastest algorithm for each conv. // - // In theory doing this here is way too early: It needs to happen after - // layout assignment, because the layout of the inputs/outputs affects the - // speed of the conv. But currently we only allow only one input/output - // layout when calling cudnn, so there's no ambiguity. - // - // We pick the algorithm at this early stage so we can generate better HLO. 
- // After CudnnConvolutionRewriter, our convolutions are CustomCalls which - // return a tuple (conv_result, scratch_memory), and the each conv uses 0 - // bytes of scratch: + // We pick the algorithm before fusion so we can generate better HLO. After + // CudnnConvolutionRewriter, our convolutions are CustomCalls which return a + // tuple (conv_result, scratch_memory), and the each conv uses 0 bytes of + // scratch: // // customcall = (f32[...], f32[0]) // return gte(customcall, 0) @@ -228,35 +240,15 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, // The new tuple and gte instructions then be simplified away, because // nobody is expected to use the scratch value. // - // However, if we were to run CudnnConvolutionAlgorithmPicker after layout - // assignment, fusion would already have run, and the gte(customcall, 0) - // would probably already be into a fusion node. We can't simplify across - // HloComputation boundaries, so in this case we wouldn't be able to - // simplify away the new_tuple bits. - // - // We'll need to revisit this if we ever allow multiple layouts for the - // inputs/outputs of a cudnn convolution. + // However, if we were to run CudnnConvolutionAlgorithmPicker after fusion + // the gte(customcall, 0) would probably already be into a fusion node. We + // can't simplify across HloComputation boundaries, so in this case we + // wouldn't be able to simplify away the new_tuple bits. pipeline.AddPass(stream_exec, device_allocator); // Clean up new_tuple described above. pipeline.AddPass(); - pipeline.AddPass(); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); - } - - { - HloPassPipeline pipeline("layout_assignment"); - pipeline.AddPass( - hlo_module->device_entry_computation_layout()); - - // The LayoutAssignment pass may leave behind kCopy instructions which are - // duplicate or NOPs, so remove them with algebraic simplification and CSE. 
- pipeline.AddPass>( - /*is_layout_sensitive=*/true, - /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { - return true; - }); pipeline.AddPass(/*is_layout_sensitive=*/true); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } @@ -267,6 +259,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, fusion.AddPass(/*may_duplicate=*/false); fusion.AddPass(/*may_duplicate=*/true); fusion.AddPass(); + fusion.AddPass(); TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); HloPassPipeline reduce_pipeline("reduce-precision"); @@ -283,12 +276,21 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); } } - return tensorflow::Status::OK(); + + { + // Do an aggressive LICM pass over while loops. In particular, this hoists + // constants that were sunk by WhileLoopConstantSinking. Leaving them in + // the while loop may result in unnecessary copies. + HloPassPipeline pipeline("while-loop-licm"); + pipeline.AddPass(true); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + return Status::OK(); } // Modifies the given HLO module so that it will be accepted by IrEmitter. // Unlike optimization passes, the passes are necessary for correctness. -tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { +Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // In some cases, we have to place the result of an instruction in a temporary // buffer. 
For instance, the buffer that holds an external parameter is // assumed immutable at this point, and should not be reused for output diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index 9db85bc788bde46c890a46ce9b0902ddce3f5675..c5ccdd4a7dcec02ddab8a1f748659de41f6202d2 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -78,14 +78,13 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); } - } else if (IsCustomCallToDnnConvolution(*hlo)) { - // The last two arguments to a CUDNN convolution are two HLO constants for - // cudnn algorithm and tensor_ops_enabled flag, which shouldn't be copied. - for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } - } else if (ImplementedAsLibraryCall(*hlo)) { - // For all other library calls, materialize all the operands into memory. + } else if (ImplementedAsLibraryCall(*hlo) || + hlo->opcode() == HloOpcode::kCrossReplicaSum) { + // For all other library calls and cross-replica-sum, materialize all the + // operands into memory. (Cross-replica-sum gets its constant args + // materialized even if it's not implemented as a libcall to simplify the + // implementation. It's slower, but we can constant fold away constant + // args *anyway*, so we just need to make it work.) 
for (int64 i = 0; i < hlo->operand_count(); ++i) { TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 980cc89fa03abd874a8e0a694f2abb775c1de050..25d8f720ea4791a4c94efcad6909cd0c113fbe70 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -32,12 +32,15 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace gpu { namespace { +using tensorflow::tracing::ScopedAnnotation; + // A helper class for profiling HLO in the course of GPU program execution. // All of the profiling is guarded internally, to avoid the caller needing to // have lots of conditionals sprinkled around. 
@@ -134,9 +137,10 @@ Status GpuExecutable::ExecuteThunks( const BufferAllocations& buffer_allocations, bool block_host_until_done, HloExecutionProfile* hlo_execution_profile) { se::Stream* main_stream = run_options->stream(); + se::StreamExecutor* executor = main_stream->parent(); std::pair stream_compute_compatibility; - main_stream->parent()->GetDeviceDescription().cuda_compute_capability( + executor->GetDeviceDescription().cuda_compute_capability( &stream_compute_compatibility.first, &stream_compute_compatibility.second); TF_RET_CHECK(stream_compute_compatibility == compute_capability_) @@ -155,21 +159,39 @@ Status GpuExecutable::ExecuteThunks( sub_streams.reserve(thunk_schedule_->StreamCount() - 1); while (sub_streams.size() + 1 < thunk_schedule_->StreamCount()) { sub_streams.emplace_back(); - TF_ASSIGN_OR_RETURN( - sub_streams.back(), - run_options->BorrowStream(main_stream->parent()->device_ordinal())); + TF_ASSIGN_OR_RETURN(sub_streams.back(), + run_options->BorrowStream(executor->device_ordinal())); } HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream, sub_streams, hlo_module_->entry_computation()); uint64 start_micros = tensorflow::Env::Default()->NowMicros(); - // The next event enqueued on stream N must not run until the thunk at - // last_blocking_thunk_for_stream[N] completes. - std::map last_blocking_thunk_for_stream; + // This top-level trace serves two purposes: + // 1) It marks the scope of the whole XLA module. + // 2) It tells us whether tracing is enabled. We use this to avoid the + // expensive HloInstruction::ToString() calls inside the loop below if + // tracing is disabled. + ScopedAnnotation top_level_annotation(hlo_module_->name(), "XLA GPU module"); + std::map> thunk_to_finish_event; for (Thunk* thunk : thunk_schedule_->TotalOrder()) { - TF_RETURN_IF_ERROR(thunk->Initialize(*this)); + // Annotate execution of this op if tracing was enabled when we started + // running this module. 
If tracing is enabled *while* we're running the + // module, we won't get any data, but that's probably an OK trade-off. + // + // TODO(jlebar): Should we cache the results of HloInstruction::ToString(), + // since we expect it to be an expensive call? + tensorflow::gtl::optional op_annotation; + if (top_level_annotation.IsEnabled()) { + op_annotation.emplace( + thunk->hlo_instruction() != nullptr + ? thunk->hlo_instruction()->ToString(HloPrintOptions::Canonical()) + : "", + "XLA op"); + } + + TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor)); int32 stream_no = thunk_schedule_->StreamNumberForHlo(*thunk->hlo_instruction()); se::Stream* stream = @@ -179,18 +201,10 @@ Status GpuExecutable::ExecuteThunks( stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get()); } - if (last_blocking_thunk_for_stream.count(stream_no)) { - stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, - last_blocking_thunk_for_stream[stream_no]) - .get()); - last_blocking_thunk_for_stream.erase(stream_no); - } - // If this thunk requests it, wait for all currently-executing thunks to // finish. This is useful e.g. if the thunk is about to perform autotuning. 
if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) { TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone()); - last_blocking_thunk_for_stream.clear(); } profiler.StartOperation(); @@ -198,22 +212,11 @@ Status GpuExecutable::ExecuteThunks( << thunk->hlo_instruction()->ToString() << " on stream " << stream_no; TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); - if (thunk_schedule_->Depended(thunk) || thunk->ShouldBlockFutureThunks()) { + if (thunk_schedule_->Depended(thunk)) { auto finish_event = MakeUnique(main_stream->parent()); finish_event->Init(); stream->ThenRecordEvent(finish_event.get()); thunk_to_finish_event[thunk] = std::move(finish_event); - - if (thunk->ShouldBlockFutureThunks()) { - // Set last_blocking_thunk_for_stream on all streams other than this one - // so that all other streams will wait for this thunk to complete before - // executing any events that occur later in the total order. - for (int32 i = 0; i < sub_streams.size() + 1; ++i) { - if (i != stream_no) { - last_blocking_thunk_for_stream[i] = thunk; - } - } - } } profiler.FinishOperation(thunk->hlo_instruction()); } @@ -286,8 +289,8 @@ StatusOr GpuExecutable::ExecuteOnStream( se::StreamExecutor* executor = run_options->stream()->parent(); TF_ASSIGN_OR_RETURN( auto buffer_allocations, - buffer_allocations_builder.Build(*assignment_, executor->device_ordinal(), - memory_allocator)); + buffer_allocations_builder.Build( + assignment_.get(), executor->device_ordinal(), memory_allocator)); bool block_host_until_done = !memory_allocator->AllowsAsynchronousDeallocation(); @@ -329,8 +332,7 @@ StatusOr GpuExecutable::ExecuteOnStream( buffers_in_result.insert(src_base); return Status::OK(); })); - TF_RETURN_IF_ERROR( - buffer_allocations->TearDown(buffers_in_result, *assignment_)); + TF_RETURN_IF_ERROR(buffer_allocations->TearDown(buffers_in_result)); return std::move(shaped_buffer); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc 
b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 89f1e625884568bf7370b3801d851ef4846c2a98..8bf62dde8b9948375fc493fd1a524cfa7b062502 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -18,31 +18,72 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { namespace gpu { -// cuDNN convolutions are called with specific layouts on the input, output, -// and filter: -// -// input: DataLayout::kBatchDepthYX -// output: DataLayout::kBatchDepthYX -// filter: FilterLayout::kOutputInputYX -// -// The order dimensions in the constant name is major-to-minor (eg, the -// most-major dimension of the input is batch, most-minor is X). The -// specific dimension numbers these named dimensions correspond to is -// determined by the ConvolutionDimensionNumbers argument. Y is spatial -// dimension 0, and X is spatial dimension 1. -// -// TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls. -static Status AddBackendConstraintsToDnnConvCustomCall( +using stream_executor::dnn::DataLayout; +using stream_executor::dnn::FilterLayout; + +static bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) { + int major, minor; + CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major, + &minor)); + return major >= 7; +} + +// Returns (input, filter, output) layouts. 
+static std::tuple<DataLayout, FilterLayout, DataLayout> +HeuristicLayoutAssignment(const HloInstruction* instr, + stream_executor::StreamExecutor* stream_executor) { + // DataLayout and FilterLayout use weird enum names. Translations: + // N <=> Batch or Output + // C <=> Depth or Input + // H <=> Y + // W <=> X + // + // Therefore kOutputInputYX and kBatchDepthYX both mean NCHW. + + // As of today, our empirical evidence is that cudnn 7.0 is faster on V100 x + // fp16 with the mostly-NHWC layout. The heuristic may change as the cudnn + // version and the hardware change. + if (!(instr->operand(0)->shape().element_type() == xla::PrimitiveType::F16 && + IsVoltaOrLater(*stream_executor))) { + return std::make_tuple(DataLayout::kBatchDepthYX, + FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX); + } + VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString(); + // For BackwardInput that has stride, full NHWC layouts run significantly + // slower than (NHWC, NCHW, NCHW) or (NHWC, NCHW, NHWC). + // + // TODO(timshen): more closely compare (NHWC, NCHW, NCHW) and (NHWC, NCHW, + // NHWC). + if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget && + window_util::HasStride(instr->window())) { + return std::make_tuple(DataLayout::kBatchYXDepth, + FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX); + } + return std::make_tuple(DataLayout::kBatchYXDepth, + FilterLayout::kOutputYXInput, + DataLayout::kBatchYXDepth); +} + +// Adds layout constraints on the cudnn custom-call instruction. The layout +// constraints are represented in terms of minor_to_major fields of both +// operands and the output shape. Depending on the underlying algorithm, one of +// { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen.
+Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( HloInstruction* instr, LayoutConstraints* constraints) { CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString(); Shape input_shape; @@ -66,39 +107,25 @@ static Status AddBackendConstraintsToDnnConvCustomCall( << instr->custom_call_target(); } - // Construct minor-to-major dimension orders for operands and result. - // cuDNN's convolution APIs support the BDYX layout for activations/output - // and the OIYX layout for weights. - // TODO(b/29399649): Be more flexible about handling layouts of cuDNN - // calls after we switch to cuDNN v5. - const ConvolutionDimensionNumbers& dimension_numbers = - instr->convolution_dimension_numbers(); - std::vector input_layout; - for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; i >= 0; - --i) { - input_layout.push_back(dimension_numbers.input_spatial_dimensions(i)); - } - input_layout.push_back(dimension_numbers.input_feature_dimension()); - input_layout.push_back(dimension_numbers.input_batch_dimension()); - *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout); - - std::vector filter_layout; - for (int i = dimension_numbers.kernel_spatial_dimensions_size() - 1; i >= 0; - --i) { - filter_layout.push_back(dimension_numbers.kernel_spatial_dimensions(i)); - } - filter_layout.push_back(dimension_numbers.kernel_input_feature_dimension()); - filter_layout.push_back(dimension_numbers.kernel_output_feature_dimension()); - *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout); - - std::vector output_layout; - for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; i >= 0; - --i) { - output_layout.push_back(dimension_numbers.output_spatial_dimensions(i)); + { + DataLayout input; + FilterLayout filter; + DataLayout output; + if (ConvUseLayoutHeuristic(instr->GetModule()->config())) { + std::tie(input, filter, output) = + HeuristicLayoutAssignment(instr, stream_executor_); + } else { + input = 
DataLayout::kBatchDepthYX; + filter = FilterLayout::kOutputInputYX; + output = DataLayout::kBatchDepthYX; + } + + TF_ASSIGN_OR_RETURN( + std::tie(*input_shape.mutable_layout(), *filter_shape.mutable_layout(), + *output_shape.mutable_layout()), + StreamExecutorConvLayoutsToXlaLayouts( + instr->convolution_dimension_numbers(), input, filter, output)); } - output_layout.push_back(dimension_numbers.output_feature_dimension()); - output_layout.push_back(dimension_numbers.output_batch_dimension()); - *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout); // The custom call returns a tuple of (actual_result, scratch_buffer); // call_result_buf is the logical buffer for actual_result, the thing that @@ -132,7 +159,13 @@ static Status AddBackendConstraintsToDnnConvCustomCall( Status GpuLayoutAssignment::AddBackendConstraints( LayoutConstraints* constraints) { - for (auto* instruction : constraints->computation()->instructions()) { + // Add convolution constraints in reverse postorder so that the earliest + // convolution layout propagates first. This reduces the likelihood of fusion + // nodes with copies. + auto post_order = constraints->computation()->MakeInstructionPostOrder(); + for (auto iterator = post_order.rbegin(); iterator != post_order.rend(); + ++iterator) { + HloInstruction* instruction = *iterator; if (IsCustomCallToDnnConvolution(*instruction)) { TF_RETURN_IF_ERROR( AddBackendConstraintsToDnnConvCustomCall(instruction, constraints)); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h index 51aae79c3d8d0000007f9d2926d245de838d3aca..ce24af1cf8856920ccf438b5bbd2ef28cfa8ba6f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h @@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace gpu { @@ -27,9 +28,10 @@ namespace gpu { // layout constraints for operands and results of library calls. class GpuLayoutAssignment : public LayoutAssignment { public: - explicit GpuLayoutAssignment( - const ComputationLayout& entry_computation_layout) - : LayoutAssignment(entry_computation_layout) {} + explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout, + se::StreamExecutor* stream_executor) + : LayoutAssignment(entry_computation_layout), + stream_executor_(stream_executor) {} ~GpuLayoutAssignment() override {} protected: @@ -42,6 +44,12 @@ class GpuLayoutAssignment : public LayoutAssignment { LayoutConstraints* constraints) override; bool CustomCallRequiresMajorFirstLayout( const HloInstruction* instruction) override; + + private: + Status AddBackendConstraintsToDnnConvCustomCall( + HloInstruction* instr, LayoutConstraints* constraints); + + se::StreamExecutor* stream_executor_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 7c801955943021def4ddc0accd9f318b7916ce93..e48165c1426ea04839c245bc20b851a0f1710246 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -69,7 +69,8 @@ TEST_F(LayoutAssignmentTest, Elementwise) { *computation_layout.mutable_result_layout() = ShapeLayout(result_shape_with_layout); - GpuLayoutAssignment layout_assignment(computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); for (const HloInstruction* operand : 
add->operands()) { @@ -156,7 +157,8 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { *computation_layout.mutable_result_layout() = ShapeLayout(result_shape); } - GpuLayoutAssignment layout_assignment(computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -225,7 +227,8 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { {result_shape, offset_scale_shape, offset_scale_shape})); } - GpuLayoutAssignment layout_assignment(computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -305,7 +308,8 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { {result_shape, scale_shape, scale_shape})); } - GpuLayoutAssignment layout_assignment(computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first and fourth operands to the batchnorm call should have the diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.cc b/tensorflow/compiler/xla/service/gpu/gpu_options.cc new file mode 100644 index 0000000000000000000000000000000000000000..35b4b4e20b633792de4251a4b0e89f4b579053ce --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_options.cc @@ -0,0 +1,28 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace xla { +namespace gpu { + +bool ConvUseLayoutHeuristic(const HloModuleConfig& config) { + return !config.debug_options().xla_backend_extra_options().count( + "xla_gpu_experimental_conv_disable_layout_heuristic"); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.h b/tensorflow/compiler/xla/service/gpu/gpu_options.h new file mode 100644 index 0000000000000000000000000000000000000000..498d4a94955cb2c50e0b165f28ded44ac1c0bfff --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_options.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ + +#include "tensorflow/compiler/xla/service/hlo_module_config.h" + +// Helper functions for querying options that are specific to the GPU backend. + +namespace xla { +namespace gpu { + +// Returns true if we should use heuristics to assign convolution layouts, as +// opposed to always assigning NCHW. +bool ConvUseLayoutHeuristic(const HloModuleConfig& config); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index f13727ca9b6954f6be9b9277018fcc64ee326954..7bb8df6581b49b1bf8c84a972f715e8dc119d8de 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -44,8 +44,8 @@ GpuTransferManager::GpuTransferManager() /*pointer_size=*/llvm::DataLayout(gpu::GpuCompiler::kDataLayout) .getPointerSize(0 /* default address space */)) {} -Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) { +Status GpuTransferManager::TransferLiteralToInfeed( + se::StreamExecutor* executor, const LiteralSlice& literal) { const Shape& shape = literal.shape(); VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h index d040a99975230578c270deabdfe60c61649e778c..09f8227f508a3159f3def285898e15bfad544552 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h @@ -37,7 +37,7 @@ class GpuTransferManager : public 
GenericTransferManager { ~GpuTransferManager() override {} Status TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) override; + const LiteralSlice& literal) override; Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, const void* source) override; diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc index 6436abc06cb9b0d69bc977334e68d91c03af2c98..45f0a1c645b2875cf90d2c11cfb66c3dd855d097 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc @@ -42,6 +42,14 @@ class HloScheduleTest : public HloTestBase { .ConsumeValueOrDie(); } + std::unique_ptr CreateNewModule() { + HloModuleConfig config; + auto debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_disable_multi_streaming(false); + config.set_debug_options(debug_options); + return MakeUnique("test_module", config); + } + HloVec RemoveHlo(const HloVec& input, const std::unordered_set& remove) { HloVec result(input); diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc index 3ddc1c0789d746bf021256638342364aac63e0e3..ae310beefad0c81c17fd4140b441b3a19a002e2c 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc @@ -49,13 +49,25 @@ void InfeedManager::EnqueueBuffers(const std::vector& buffers) { } InfeedBuffer* InfeedManager::BlockingDequeueBuffer() { - tensorflow::mutex_lock l(mu_); - while (enqueued_buffer_.empty()) { - cv_.wait(l); + bool became_empty = false; + InfeedBuffer* current_buffer; + { + tensorflow::mutex_lock l(mu_); + while (enqueued_buffer_.empty()) { + cv_.wait(l); + } + current_buffer = enqueued_buffer_.front(); + enqueued_buffer_.pop_front(); + dequeued_buffer_.insert(current_buffer); + if (enqueued_buffer_.empty()) { + became_empty = true; + 
} + } + if (became_empty) { + for (const auto& callback : on_empty_callbacks_) { + callback(); + } } - InfeedBuffer* current_buffer = enqueued_buffer_.front(); - enqueued_buffer_.pop_front(); - dequeued_buffer_.insert(current_buffer); return current_buffer; } @@ -88,6 +100,10 @@ se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) { return host_to_device_stream_.get(); } +void InfeedManager::RegisterOnEmptyCallback(std::function callback) { + on_empty_callbacks_.push_back(std::move(callback)); +} + InfeedManager* GetOrCreateInfeedManager() { static InfeedManager* manager = new InfeedManager; return manager; diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.h b/tensorflow/compiler/xla/service/gpu/infeed_manager.h index d5f2216d460a45085536b15f9bf6e3bd3579f9c8..a3fc15cfe36a490f38daabca9ff36fbb1012aead 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.h @@ -21,6 +21,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_MANAGER_H_ #include +#include #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/flatset.h" @@ -100,6 +101,10 @@ class InfeedManager { // returns null. se::Stream* GetStream(se::StreamExecutor* executor); + // Registers a callback that will be called when 'enqueued_buffer_' becomes + // empty. + void RegisterOnEmptyCallback(std::function callback); + private: // TODO(b/30467474): Revisit if this mutex becomes a point of // contention. @@ -122,6 +127,10 @@ class InfeedManager { // Executor that the host_to_device_stream belongs to. Not owned. se::StreamExecutor* host_to_device_executor_; + + // List of callbacks which will be called when 'enqueued_buffer_' becomes + // empty. + std::vector> on_empty_callbacks_; }; // Singleton creator-or-accessor: Returns the GPU infeed manager. 
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index c5eb7211859c8fcb728d28ba432b7e65979a194a..6c4519185b34989eb53c884ba214d69b824b113c 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -17,7 +17,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace gpu { @@ -46,6 +48,15 @@ bool IsFusile(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kTranspose; } +bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { + if (constant->opcode() != HloOpcode::kConstant || + !ShapeUtil::IsScalar(constant->shape())) { + return false; + } + auto type = constant->shape().element_type(); + return type == F16 || type == F32 || type == F64; +} + } // namespace /*static*/ bool GpuInstructionFusion::IsExpensive( @@ -66,34 +77,72 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, HloInstruction* producer = consumer->mutable_operand(operand_index); // Check if we can use output fusion for (A @ B) * alpha - if (producer->opcode() == HloOpcode::kDot) { - if (consumer->opcode() == HloOpcode::kMultiply) { - CHECK_EQ(consumer->operand_count(), 2); - int64 other_operand_index = 1 - operand_index; + if (producer->opcode() == HloOpcode::kDot || + (producer->opcode() == HloOpcode::kFusion && + producer->fused_expression_root()->opcode() == HloOpcode::kDot)) { + int64 other_operand_index = 1 - operand_index; + HloInstruction* op1 = nullptr; + HloInstruction* op2 = nullptr; + if (consumer->operand_count() == 1 && + consumer->opcode() == HloOpcode::kFusion && + consumer->fusion_kind() == 
HloInstruction::FusionKind::kLoop && + Match(consumer->fused_expression_root(), + match::Op() + .WithOpcode(HloOpcode::kMultiply) + .WithOperand(0, match::Op(&op1)) + .WithOperand(1, match::Op(&op2)))) { + CHECK(op1 != nullptr && op2 != nullptr); + // If 'consumer' is a fusion node, it should consist of a broadcast of a + // scalar constant fused into a multiply, but nothing more. So one operand + // should be a parameter, and the other should be a broadcast. + if (op1->opcode() != HloOpcode::kParameter) { + std::swap(op1, op2); + } + if (op1->opcode() != HloOpcode::kParameter || + op2->opcode() != HloOpcode::kBroadcast) { + return false; + } + if (IsIEEEFloatingPointScalarConstant(op2->operand(0))) { + return true; + } + } else if (consumer->operand_count() == 2 && + consumer->opcode() == HloOpcode::kMultiply) { const HloInstruction* alpha = consumer->operand(other_operand_index); - if (alpha->opcode() == HloOpcode::kConstant && - ShapeUtil::IsScalar(alpha->shape())) { + // Fuse if 'alpha' is a broadcast of a scalar constant. + if (alpha->opcode() == HloOpcode::kBroadcast && + alpha->dimensions().empty() && + IsIEEEFloatingPointScalarConstant(alpha->operand(0))) { return true; } } } - // Only allow to fuse transpose into an output fusion. + // Only allow fusing transpose or broadcast into an output fusion that is + // implemented as a Gemm call. if (consumer->opcode() == HloOpcode::kFusion && - consumer->fusion_kind() == HloInstruction::FusionKind::kOutput) { - if (producer->opcode() != HloOpcode::kTranspose) { - return false; - } - // Check that the transpose is the operand of a dot. 
+ consumer->fusion_kind() == HloInstruction::FusionKind::kOutput && + ImplementedAsGemm(*consumer)) { auto producer_operand_index = consumer->operand_index(producer); auto fused_parameter = consumer->fused_parameter(producer_operand_index); const std::vector& fused_parameter_users = fused_parameter->users(); - return (fused_parameter_users.size() == 1 && - fused_parameter_users[0]->opcode() == HloOpcode::kDot); + if (fused_parameter_users.size() != 1) { + return false; + } + if (producer->opcode() == HloOpcode::kTranspose) { + // Check that the transpose is an operand of a dot. + return fused_parameter_users[0]->opcode() == HloOpcode::kDot; + } + if (producer->opcode() == HloOpcode::kBroadcast) { + // Check that the broadcast is a broadcast of a scalar constant into a + // multiply. + return producer->dimensions().empty() && + IsIEEEFloatingPointScalarConstant(producer->operand(0)) && + fused_parameter_users[0]->opcode() == HloOpcode::kMultiply; + } } - // Output fusion is not currently supported on GPUs. + // Other output fusions are not currently supported on GPUs. if (producer->opcode() == HloOpcode::kFusion) { return false; } @@ -125,16 +174,46 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } + // Fuse scalar constants into loop fusion nodes. This reduces the number of + // parameters and makes matching scalar broadcasts easier. + if (ShapeUtil::IsEffectiveScalar(producer->shape()) && + consumer->opcode() == HloOpcode::kFusion && + producer->opcode() == HloOpcode::kConstant) { + return true; + } + return IsFusile(*producer) && IsFusile(*consumer) && InstructionFusion::ShouldFuse(consumer, operand_index); } +bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, + int64 operand_index) { + const HloInstruction* producer = consumer->operand(operand_index); + // The IR emitter has limited support for non-loop fusions with multi output + // at present.
+ // TODO(tjoerg): Relax this constraint to allow for arbitrary kinds of fusion. + if (consumer->opcode() == HloOpcode::kFusion && + consumer->fusion_kind() != HloInstruction::FusionKind::kLoop) { + return false; + } + // Multi-output fusion requires instructions with compatible shapes. + if (!ShapeUtil::Compatible(producer->shape(), consumer->shape())) { + return false; + } + // TODO(tjoerg): Stop calling `ShouldFuse` to relax the criteria for + // multi-output fusion. In particular, do not check whether an instruction is + // expensive to duplicate, since this doesn't matter here. + return GpuInstructionFusion::ShouldFuse(consumer, operand_index); +} + HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) { if (IsReductionToVector(*consumer)) { return HloInstruction::FusionKind::kInput; } - if (producer->opcode() == HloOpcode::kDot) { + if (producer->opcode() == HloOpcode::kDot || + (producer->opcode() == HloOpcode::kFusion && + producer->fused_expression_root()->opcode() == HloOpcode::kDot)) { return HloInstruction::FusionKind::kOutput; } if (HloOpcode::kFusion == consumer->opcode()) { diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h index 9fb06b0a244186484b1c17edf13bd28a4305a1a6..f629d9ff2c7165b652369612c30979150f93bd24 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h @@ -31,6 +31,9 @@ class GpuInstructionFusion : public InstructionFusion { bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; + bool ShouldFuseIntoMultiOutput(HloInstruction* consumer, + int64 operand_index) override; + HloInstruction::FusionKind ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) override; }; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 6c9a805ad637ceef71f6bb021154f358e6e02825..1963d9eef72d41fa0a275bea98f959671fa7e737 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -15,9 +15,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/util.h" namespace op = xla::testing::opcode_matchers; @@ -108,8 +111,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0)); + auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( + ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); @@ -125,8 +128,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0)); + auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( + ShapeUtil::MakeShape(S32, {1, 1}), 
param0, param0)); auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); @@ -140,7 +143,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { // Tests that broadcasts fused into a fusion with a reduce root. TEST_F(InstructionFusionTest, BroadcastIntoReduce) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module add { @@ -165,11 +168,11 @@ TEST_F(InstructionFusionTest, BroadcastIntoReduce) { HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Fusion()); EXPECT_THAT(root->fused_expression_root(), - op::Reduce(op::Broadcast(op::Parameter()), op::Parameter())); + op::Reduce(op::Broadcast(op::Constant()), op::Constant())); } TEST_F(InstructionFusionTest, BitcastIntoAdd) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY BroadcastIntoAdd { @@ -191,7 +194,7 @@ TEST_F(InstructionFusionTest, BitcastIntoAdd) { } TEST_F(InstructionFusionTest, AddIntoBitcast) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY BroadcastIntoAdd { @@ -213,7 +216,7 @@ TEST_F(InstructionFusionTest, AddIntoBitcast) { } TEST_F(InstructionFusionTest, DontFuseGTE) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY DontFuseGTE { p0 = (f32[10], f32[10]) parameter(0) @@ -229,15 +232,16 @@ TEST_F(InstructionFusionTest, DontFuseGTE) { } TEST_F(InstructionFusionTest, DotOutputFusion) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { - constant = f32[] constant(3) + alpha = f32[] constant(3) + broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={} p0 = f32[4,3]{1,0} parameter(0) p1 = f32[4,3]{1,0} parameter(1) transpose = f32[3,4]{1,0} transpose(p1), dimensions={1, 0} - dot = f32[4,4]{1,0} dot(p0, transpose) - ROOT mul = f32[4,4] 
multiply(constant, dot) + dot = f32[4,4]{1,0} dot(p0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT mul = f32[4,4] multiply(dot, broadcast) })") .ValueOrDie(); @@ -247,16 +251,17 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Fusion()); + EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kOutput); EXPECT_THAT( root->fused_expression_root(), - op::Multiply(op::Parameter(), - op::Dot(op::Parameter(), op::Transpose(op::Parameter())))); + op::Multiply(op::Dot(op::Parameter(), op::Transpose(op::Parameter())), + op::Broadcast(op::Constant()))); } // Compute sum(1/p0), where p0 has type f32, twice. Check that the division is // duplicated and fused into both reduces. TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module Add { lhs = f32[] parameter(0) @@ -279,14 +284,15 @@ TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { .ValueOrDie()); HloInstruction* root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion())); + EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion())) + << module->ToString(); } // Compute sum(100/p0), where p0 has type s32, twice. Check that the division // is *not* duplicated and fused into both reduces, because we say that integer // division is not cheap. 
TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module Add { lhs = s32[] parameter(0) @@ -306,7 +312,298 @@ TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) { EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) .Run(module.get()) - .ValueOrDie()); + .ValueOrDie()) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, DotOutputFusionImpossible) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY NoOutputFusion { + alpha = f32[] constant(3) + broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={} + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[3,4]{1,0} parameter(1) + dot = f32[4,4]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + d = f32[4,4]{1,0} multiply(dot, dot) + ROOT mul = f32[4,4] multiply(d, broadcast) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop); + EXPECT_THAT(root->fused_expression_root(), + op::Multiply(op::Multiply(op::Parameter(), op::Parameter()), + op::Broadcast(op::Constant()))); +} + +// Counts the HLO ops with a given op code in the specified module. +static int Count(const HloModule& module, HloOpcode op) { + int count = 0; + for (const auto* computation : module.computations()) { + for (const auto* instruction : computation->instructions()) { + if (instruction->opcode() == op) { + ++count; + } + } + } + return count; +} + +// Returns an HLO instruction from the given computation with the op code. 
+static StatusOr<const HloInstruction*> FindHloInstruction( + const HloComputation& computation, HloOpcode op) { + for (const auto* instruction : computation.instructions()) { + if (instruction->opcode() == op) { + return instruction; + } + } + return NotFound( + "Computation '%s' does not contain an instruction with op code '%s'.", + computation.name().c_str(), HloOpcodeString(op).c_str()); +} + +TEST_F(InstructionFusionTest, MultiOutputFusion) { + // sub --> add --> tuple + // \---------------/ + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[4,3]{1,0} parameter(2) + sub = f32[4,3]{1,0} subtract(p0, p2) + add = f32[4,3]{1,0} add(sub, p1) + ROOT tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub, add) + })") + .ValueOrDie(); + + ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + SCOPED_TRACE(module->ToString()); + + // Expect that there is one multi-output fusion and subtract has not been + // duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); + EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); + TF_ASSERT_OK_AND_ASSIGN( + const HloInstruction* fusion, + FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); + EXPECT_THAT( + fusion->fused_expression_root(), + op::Tuple(op::Add(op::Subtract(), op::Parameter()), op::Subtract())); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionExpensiveOp) { + // tanh --> add --> tuple + // \---------------/ + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + tanh = f32[4,3]{1,0} tanh(p0) + add = f32[4,3]{1,0} add(tanh, p1) + ROOT tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(tanh, add) + })") + .ValueOrDie(); + + // TODO(tjoerg): Allow multi-output fusion for expensive operations like tanh. 
+ ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, MultiOutputFusion2) { + // sub --> add1 --\--------\ + // \----------> add2 --> tuple + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[4,3]{1,0} parameter(2) + sub = f32[4,3]{1,0} subtract(p0, p2) + add1 = f32[4,3]{1,0} add(sub, p1) + add2 = f32[4,3]{1,0} add(sub, add1) + ROOT tuple = (f32[4,3]{1,0}) tuple(add1, add2) + })") + .ValueOrDie(); + + ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + SCOPED_TRACE(module->ToString()); + + // Expect that there is one multi-output fusion and subtract has not been + // duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); + EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); + TF_ASSERT_OK_AND_ASSIGN( + const HloInstruction* fusion, + FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Add(op::Subtract(), op::Add()), + op::Add(op::Subtract(), op::Parameter()))); +} + +TEST_F(InstructionFusionTest, MultiOutputFusion3) { + // sub --> add1 ----\--------\ + // \ --> add2 --> add3 --> tuple + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[4,3]{1,0} parameter(2) + p3 = f32[4,3]{1,0} parameter(3) + sub = f32[4,3]{1,0} subtract(p0, p2) + add1 = f32[4,3]{1,0} add(sub, p1) + add2 = f32[4,3]{1,0} add(p2, sub) + add3 = f32[4,3]{1,0} add(add1, add2) + ROOT tuple = (f32[4,3]{1,0}) tuple(add3, add2) + })") + .ValueOrDie(); + + ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + SCOPED_TRACE(module->ToString()); + + // Expect that there is one multi-output fusion and subtract has 
not been + // duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); + EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); + TF_ASSERT_OK_AND_ASSIGN( + const HloInstruction* fusion, + FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Add(op::Add(), op::Add()), + op::Add(op::Parameter(), op::Subtract()))); +} + +TEST_F(InstructionFusionTest, NoCyclesDueToMultiOutputFusion) { + // sub --> mul ---\ + // \--> call --> add --> tuple + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + c = f32[] constant(42) + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + sub = f32[4,3]{1,0} subtract(p0, p1) + mul = f32[4,3]{1,0} multiply(sub, c) + call = f32[4,3]{1,0} custom-call(sub), custom_call_target="foo" + add = f32[4,3]{1,0} add(mul, call) + ROOT tuple = (f32[4,3]{1,0}) tuple(add) + })") + .ValueOrDie(); + + ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + // Visit instructions in post order to detect cycles. + // TODO(tjoerg): Add cycle detection to the HloVerifier. + class DummyVisitor : public DfsHloVisitorWithDefault { + public: + DummyVisitor() {} + Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { + return Status::OK(); + } + } visitor; + for (const HloComputation* computation : module->MakeComputationPostOrder()) { + // Accept will return a FailedPrecondition when a cycle is detected. 
+ EXPECT_TRUE(computation->root_instruction()->Accept(&visitor).ok()); + } +} + +TEST_F(InstructionFusionTest, NoMultiOutputFusionWithIncompatibleShapes) { + // sub[2,3] --> add[4,3] --> tuple([2,3], [4,3]) + // \-------------------------/ + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[2,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[2,3]{1,0} parameter(2) + sub = f32[2,3]{1,0} subtract(p0, p2) + add = f32[4,3]{1,0} add(sub, p1) + ROOT tuple = (f32[2,3]{1,0}, f32[4,3]{1,0}) tuple(sub, add) + })") + .ValueOrDie(); + + // Multi-output fusion requires shapes to be compatible. Since `sub` and `add` + // have incompatible shapes, expect that no multi-output fusion happens. + ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, FuseIntoInputFusionInstruction) { + auto module = ParseHloString(R"( + HloModule test_module + + add_computation { + add_lhs = f32[] parameter(0) + add_rhs = f32[] parameter(1) + ROOT add_root = f32[] add(add_lhs, add_rhs) + } + + fused_computation { + p1 = f32[10] parameter(0) + zero = f32[] constant(0) + ROOT f2_root = f32[] reduce(p1, zero), dimensions={0}, + to_apply=add_computation + } + + ENTRY entry { + p0 = f32[10] parameter(0) + mul = f32[10] multiply(p0, p0) + fusion = f32[] fusion(mul), kind=kInput, calls=fused_computation + ROOT tuple = (f32[10], f32[]) tuple(fusion, mul) + })") + .ValueOrDie(); + + // Multi-output fusion is not supported for non-loop fusions at present. Since + // `fused_computation` is a input fusion, expect no multi-output fusion to + // happen. 
+ ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, FuseScalarConstant) { + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY FuseScalarConstant { + p0 = f32[] parameter(0) + c0 = f32[] constant(1) + add1 = f32[] add(p0, c0) + b0 = f32[2]{0} broadcast(add1), dimensions={} + c1 = f32[2]{0} constant({1, 2}) + ROOT add2 = f32[2]{0} add(b0, c1) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Add(op::Broadcast(op::Add(op::Parameter(), op::Constant())), + op::Parameter())); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 96199035b9e6d39332861079e16b5a4d20eee1a8..67890bfed1136796c83c7ef6912ffc1ab1b7e332 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -59,6 +59,25 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, !ShapeUtil::HasZeroElements(lhs_shape) && !ShapeUtil::HasZeroElements(rhs_shape); } + +bool DotImplementedAsGemm(const HloInstruction& dot) { + CHECK_EQ(dot.opcode(), HloOpcode::kDot); + const Shape& lhs_shape = dot.operand(0)->shape(); + const Shape& rhs_shape = dot.operand(1)->shape(); + + // If gemm can accept the operand shapes, use it rather than a custom + // kernel. + if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape())) { + // The size of the reduction dimension should match. The shape inference + // guarantees this invariant, so the check here is for programming + // errors. 
+ const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); + CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), + rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); + return true; + } + return false; +} } // namespace bool ImplementedAsGemm(const HloInstruction& hlo) { @@ -69,20 +88,7 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { // For certain types of Dot, we can call pre-canned BLAS gemm. if (hlo.opcode() == HloOpcode::kDot) { - const Shape& lhs_shape = hlo.operand(0)->shape(); - const Shape& rhs_shape = hlo.operand(1)->shape(); - - // If gemm can accept the operand shapes, use it rather than a custom - // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { - // The size of the reduction dimension should match. The shape inference - // guarantees this invariant, so the check here is for programming - // errors. - const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers(); - CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), - rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); - return true; - } + return DotImplementedAsGemm(hlo); } if (hlo.opcode() == HloOpcode::kFusion && @@ -94,7 +100,7 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { dot = hlo.fused_expression_root()->operand(1); } if (dot->opcode() == HloOpcode::kDot) { - return ImplementedAsGemm(*dot); + return DotImplementedAsGemm(*dot); } } @@ -156,19 +162,8 @@ static HloInstruction* CreateCudnnConv( Shape call_shape = ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); - // Our CustomCall takes four arguments: The conv lhs and rhs, the cudnn - // algorithm to use, and a boolean indicating whether to use tensor cores. - // - // It's up to a later pass to choose the algorithm and decide whether to use - // tensor cores, so to indicate that we haven't yet made a choice, we speicfy - // -1 and false for those args. 
- HloInstruction* negative_one = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-1))); - HloInstruction* false_constant = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); - HloInstruction* custom_call = - computation->AddInstruction(HloInstruction::CreateCustomCall( - call_shape, {lhs, rhs, negative_one, false_constant}, call_target)); + HloInstruction* custom_call = computation->AddInstruction( + HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target)); custom_call->set_window(window); custom_call->set_convolution_dimension_numbers(dnums); return custom_call; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 1e0db2821a2c212d0f212ae94ab69231bc6053ea..547af33e9a98c03e1429366172f9a401e385a9d1 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -94,7 +94,10 @@ Status IrEmitter::HandleConstant(HloInstruction* constant) { << std::endl << " its type: " << llvm_ir::DumpToString(*global_for_const->getType()); - bindings_.BindHloToIrValue(*constant, global_for_const); + llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( + global_for_const, + llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); + bindings_.BindHloToIrValue(*constant, shape_constant); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index b0accc08d479258d65a18202122e4c9e90ff78d0..e55dfc6dae844ceb1d28ad389d133c80823bad9a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -120,10 +120,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* GetBasePointer(const HloInstruction& inst) const { return bindings_.GetBasePointer(inst); } - // A convenient helper for calling 
BufferAssignment::GetUniqueTopLevelSlice. - BufferAllocation::Slice GetAllocationSlice(const HloInstruction& hlo) const { + // A convenient helper for calling BufferAssignment::GetUniqueSlice. + BufferAllocation::Slice GetAllocationSlice( + const HloInstruction& hlo, const ShapeIndex& index = {}) const { return ir_emitter_context_->buffer_assignment() - .GetUniqueTopLevelSlice(&hlo) + .GetUniqueSlice(&hlo, index) .ConsumeValueOrDie(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 71aada080ae8df70bffce3e1854b5fbd833efd23..bb47a4280541ce2806472aa9365bb0ef38c0c3b3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/core/lib/core/status.h" @@ -116,6 +117,26 @@ Status IrEmitterNested::HandleParameter(HloInstruction* parameter) { Status IrEmitterNested::EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator) { + // For MOF we give the loop emitter an array for every output it should + // generate. 
+ if (hlo.IsMultiOutputFusion()) { + std::vector<llvm_ir::IrArray> target_arrays; + for (int64 i = 0, e = ShapeUtil::TupleElementCount(hlo.shape()); i != e; + ++i) { + target_arrays.push_back(GetIrArray(hlo, hlo, {i})); + } + TF_RETURN_IF_ERROR( + llvm_ir::LoopEmitter(element_generator, target_arrays, &ir_builder_) + .EmitLoop()); + + std::vector<llvm::Value*> tuple_operand_ptrs; + for (const llvm_ir::IrArray& array : target_arrays) { + tuple_operand_ptrs.push_back(array.GetBasePointer()); + } + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &ir_builder_, + module_); + return Status::OK(); + } return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &ir_builder_) .EmitLoop(); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 83d90296df8ec75c29c537a90c6292e4f4f0e0ae..726434c3dfd4f1ef866d2eb9b6d7eb8b659e0984 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" @@ -58,6 +59,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/ops.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" @@ -79,6 +81,7 @@ namespace { using llvm_ir::IrName; using tensorflow::gtl::ArraySlice; +using tensorflow::gtl::InlinedVector; using tensorflow::gtl::nullopt; using tensorflow::gtl::optional; using tensorflow::strings::StrCat; @@ -267,7 +270,10 @@ int ComputeMaxUnrollFactor(const HloInstruction* hlo) { // Find the largest possible power of two to unroll by. // TODO(kramerb): Make this smarter. - int64 num_elements = ShapeUtil::ElementsIn(hlo->shape()); + const Shape& element_shape = hlo->IsMultiOutputFusion() + ? ShapeUtil::GetSubshape(hlo->shape(), {0}) + : hlo->shape(); + int64 num_elements = ShapeUtil::ElementsIn(element_shape); for (int i = max_unroll_factor; i > 1; i /= 2) { if (num_elements % i == 0) { return i; @@ -419,15 +425,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - const HloInstruction* algorithm_inst = custom_call->operand(2); - CHECK(algorithm_inst->IsConstant()) << algorithm_inst->ToString(); - int64 algorithm = algorithm_inst->literal().Get({}); - - const HloInstruction* tensor_ops_enabled_inst = custom_call->operand(3); - CHECK(tensor_ops_enabled_inst->IsConstant()) - << tensor_ops_enabled_inst->ToString(); - bool tensor_ops_enabled = tensor_ops_enabled_inst->literal().Get({}); - + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + custom_call->backend_config()); const auto& target = 
custom_call->custom_call_target(); std::unique_ptr thunk; if (target == kCudnnConvForwardCallTarget) { @@ -442,7 +441,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, tensor_ops_enabled, custom_call); + backend_config.algorithm(), backend_config.tensor_ops_enabled(), + custom_call); } else if (target == kCudnnConvBackwardInputCallTarget) { thunk = MakeUnique( CudnnConvKind::kBackwardInput, @@ -455,7 +455,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, tensor_ops_enabled, custom_call); + backend_config.algorithm(), backend_config.tensor_ops_enabled(), + custom_call); } else if (target == kCudnnConvBackwardFilterCallTarget) { thunk = MakeUnique( CudnnConvKind::kBackwardFilter, @@ -468,7 +469,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, tensor_ops_enabled, custom_call); + backend_config.algorithm(), backend_config.tensor_ops_enabled(), + custom_call); } else { LOG(FATAL) << "Unexpected custom call target: " << custom_call->custom_call_target(); @@ -496,12 +498,31 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // initializes the output array to the initial value of the reduce. 
if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) { switch (root->opcode()) { + case HloOpcode::kTuple: case HloOpcode::kReduce: { VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString(); - TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, - BuildInitializerThunk(fusion)); std::vector> thunks; - thunks.push_back(std::move(initializer_thunk)); + ArraySlice output_instructions = + root->opcode() == HloOpcode::kTuple + ? root->operands() + : ArraySlice(&root, 1); + + // For multi-output fusion emit an initializer for each tuple element. + // Otherwise it's sufficient to just initialize the single output. + HloInstruction* first_reduce = nullptr; + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + if (output_instructions[i]->opcode() == HloOpcode::kReduce) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr initializer_thunk, + BuildInitializerThunk(fusion, output_instructions[i] == root + ? ShapeIndex() + : ShapeIndex({i}))); + thunks.push_back(std::move(initializer_thunk)); + first_reduce = + first_reduce == nullptr ? output_instructions[i] : first_reduce; + } + } + CHECK(first_reduce != nullptr); thunks.push_back(BuildKernelThunk(fusion)); thunk_sequence_->emplace_back( MakeUnique(std::move(thunks), fusion)); @@ -515,11 +536,50 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); - Shape input_shape = root->operand(0)->shape(); - return EmitReductionToVector( - root, input_shape, fused_emitter.GetGenerator(root->operand(0)), - fused_emitter.GetGenerator(root->operand(1)), root->dimensions(), - root->to_apply()); + // For multi-output fusion CHECK the constraints and feed all the + // reduces into a single loop code generator. Single-output reduce + // fusion is a special case of that. 
+ InlinedVector input_gens; + InlinedVector init_value_gens; + std::vector> + extra_output_gens; + InlinedVector reducers; + InlinedVector reduce_output_shapes; + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + const HloInstruction* inst = output_instructions[i]; + ShapeIndex output_shape_index; + if (root->opcode() == HloOpcode::kTuple) { + output_shape_index = {i}; + } + // TODO(kramerb): CHECK that layouts are equal. Currently this + // breaks multioutputfusion_test. The test has pre-fused + // instructions, but layout_assignment will not assign any layouts + // for instructions inside of a fused computation. It just removes + // the layouts instead. + if (inst->opcode() == HloOpcode::kReduce) { + CHECK(ShapeUtil::Compatible(first_reduce->shape(), inst->shape())); + CHECK(ShapeUtil::Compatible(first_reduce->operand(0)->shape(), + inst->operand(0)->shape())); + CHECK(ShapeUtil::Compatible(first_reduce->operand(1)->shape(), + inst->operand(1)->shape())); + CHECK(first_reduce->dimensions() == inst->dimensions()); + input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); + init_value_gens.push_back( + fused_emitter.GetGenerator(inst->operand(1))); + reducers.push_back(inst->to_apply()); + reduce_output_shapes.push_back(std::move(output_shape_index)); + } else { + CHECK(ShapeUtil::Compatible(first_reduce->operand(0)->shape(), + inst->shape())); + extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst), + std::move(output_shape_index)); + } + } + const Shape& input_shape = first_reduce->operand(0)->shape(); + return EmitReductionToVector(first_reduce, input_shape, input_gens, + init_value_gens, + first_reduce->dimensions(), reducers, + reduce_output_shapes, extra_output_gens); } default: LOG(FATAL) << "Bad opcode for input fusion: " @@ -565,12 +625,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { return Status::OK(); } - int unroll_factor = 1; - // TODO(kramerb): Unrolling multi-output loop fusions too. 
- if (!fusion->IsMultiOutputFusion()) { - CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop); - unroll_factor = ComputeMaxUnrollFactor(fusion); - } + CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop); + int unroll_factor = ComputeMaxUnrollFactor(fusion); thunk_sequence_->emplace_back(BuildKernelThunk(fusion, unroll_factor)); return IrEmitter::HandleFusion(fusion); @@ -908,10 +964,33 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { return IrEmitter::HandleCopy(copy); } +Status IrEmitterUnnested::EmitExtraOutputsForReduce( + const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { + for (int i = 0; i != extra_output_gens.size(); ++i) { + const HloInstruction* output = reduce->parent()->FusionInstruction(); + llvm::Value* extra_output_address = + GetIrArray(*output, *output, extra_output_gens[i].second) + .EmitArrayElementAddress(index, &ir_builder_, + "extra_output_element_address"); + TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, + extra_output_gens[i].first(index)); + ir_builder_.CreateStore(extra_output_ir_value, extra_output_address); + } + return Status::OK(); +} + Status IrEmitterUnnested::EmitReductionToScalar( HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) { + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { // Number of elements processed by a single thread. 
constexpr int64 kTileSize = 16; int64 num_elems = ShapeUtil::ElementsIn(input_shape); @@ -963,14 +1042,19 @@ Status IrEmitterUnnested::EmitReductionToScalar( // auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + const int num_reduces = reducers.size(); llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); - { - TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value, - init_value_gen(llvm_ir::IrArray::Index({}))); + std::vector partial_reduction_result_addresses; + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index({}))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); } llvm::Value* x_in_tiles = tile_index[0]; @@ -1003,11 +1087,16 @@ Status IrEmitterUnnested::EmitReductionToScalar( llvm_ir::IrArray::Index input_index( /*linear=*/x, input_shape, &ir_builder_); llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); - TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, input_gen(input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); - return (EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, input_address}, - partial_reduction_result_address)); + for (int i = 0; i != num_reduces; ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + input_gens[i](input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + 
{partial_reduction_result_addresses[i], input_address}, + partial_reduction_result_addresses[i])); + } + return EmitExtraOutputsForReduce(reduce, input_index, extra_output_gens); }; // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's @@ -1042,20 +1131,24 @@ Status IrEmitterUnnested::EmitReductionToScalar( : element_ir_type; for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1; shuffle_distance /= 2) { - llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(partial_reduction_result_address, - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( element_ir_type, nullptr, "result_from_other_lane"); - ir_builder_.CreateStore( - EmitShuffleDown(partial_reduction_result, - ir_builder_.getInt32(shuffle_distance), &ir_builder_), - ir_builder_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, result_from_other_lane}, - partial_reduction_result_address)); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( + ir_builder_.CreateBitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); + ir_builder_.CreateStore( + EmitShuffleDown(partial_reduction_result, + ir_builder_.getInt32(shuffle_distance), + &ir_builder_), + ir_builder_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], result_from_other_lane}, + partial_reduction_result_addresses[i])); + } } const HloInstruction* output = @@ -1071,14 +1164,21 @@ Status IrEmitterUnnested::EmitReductionToScalar( "lane_id_is_zero", &ir_builder_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &ir_builder_); - 
llvm::Value* output_address = - GetIrArray(*output, *output) - .EmitArrayElementAddress( - llvm_ir::IrArray::Index(/*linear=*/ir_builder_.getInt64(0), - output->shape(), &ir_builder_), - &ir_builder_, "output_element_address"); - return EmitAtomicOperationForNestedComputation( - *reducer, output_address, partial_reduction_result_address); + + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* output_address = + GetIrArray(*output, *output, reduce_output_shapes[i]) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index( + /*linear=*/ir_builder_.getInt64(0), + ShapeUtil::GetSubshape(output->shape(), + reduce_output_shapes[i]), + &ir_builder_), + &ir_builder_, "output_element_address"); + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, partial_reduction_result_addresses[i])); + } + return Status::OK(); }; // Emit a parallel loop that iterates through all input tiles, one per thread. @@ -1098,8 +1198,13 @@ Status IrEmitterUnnested::EmitReductionToScalar( Status IrEmitterUnnested::EmitColumnReduction( int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) { + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { // Divide the input matrix into tiles of size Kx1. For example, when the // input matrix is 4x4 and K=2, the tiled matrix looks like // @@ -1109,9 +1214,13 @@ Status IrEmitterUnnested::EmitColumnReduction( // 4567 // Numbers indicate tile IDs. // // Each tile is first partially reduced to a scalar by a thread, and then the - // scalar is accumulated to the output vector using atomic operations. We - // choose 16 as the tile size, which matches Eigen's ColumnReduceKernel. 
- constexpr int64 kTileSize = 16; + // scalar is accumulated to the output vector using atomic operations. + // + // We choose 128 as the tile size based on empirical evidence. It's big enough + // to reduce the amount of atomic adds in the end, maximizing the memory + // bandwidth. + constexpr int64 kTileSize = 128; + // If the height is not a multiple of the tile size, we pad the bottom of the // input matrix. const int64 height_in_tiles = CeilOfRatio(height, kTileSize); @@ -1141,15 +1250,20 @@ Status IrEmitterUnnested::EmitColumnReduction( // } auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + const int num_reduces = reducers.size(); // Emit the loop body that reduces one tile. llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); - { - TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value, - init_value_gen(llvm_ir::IrArray::Index({}))); + std::vector partial_reduction_result_addresses; + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." 
+ llvm::Twine(i)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index({}))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); } // Emit an inner for-loop that partially reduces the elements in the given @@ -1207,13 +1321,18 @@ Status IrEmitterUnnested::EmitColumnReduction( .SourceIndexOfTranspose(normalized_input_shape, input_shape, transpose_dimension_mapping, &ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, - input_gen(input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); + for (int i = 0; i != num_reduces; ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + input_gens[i](input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], input_address}, + partial_reduction_result_addresses[i])); + } + return EmitExtraOutputsForReduce(reduce, input_index, + extra_output_gens); } - return (EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, input_address}, - partial_reduction_result_address)); }; // y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's @@ -1242,13 +1361,20 @@ Status IrEmitterUnnested::EmitColumnReduction( &ir_builder_); const HloInstruction* output = reduce->IsFused() ? 
reduce->parent()->FusionInstruction() : reduce; - llvm::Value* output_address = - GetIrArray(*output, *output) - .EmitArrayElementAddress( - llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), - &ir_builder_, "output_element_address"); - return EmitAtomicOperationForNestedComputation( - *reducer, output_address, partial_reduction_result_address); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* output_address = + GetIrArray(*output, *output, reduce_output_shapes[i]) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index( + x, + ShapeUtil::GetSubshape(output->shape(), + reduce_output_shapes[i]), + &ir_builder_), + &ir_builder_, "output_element_address"); + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, partial_reduction_result_addresses[i])); + } + return Status::OK(); }; // Emit a parallel loop that iterate through all input tiles. @@ -1266,12 +1392,42 @@ Status IrEmitterUnnested::EmitColumnReduction( .EmitLoop(IrName(reduce)); } +static std::pair ComputeTilingSchemeForReduction( + int64 depth, int64 width, int64 kWarpSize) { + constexpr int64 kTargetNumElementsPerThread = 64; + int64 x_tile_size = kTargetNumElementsPerThread; + int64 z_tile_size = 1; + + // Only tile along the x dimension with tile size kTargetNumElementsPerThread + // if doing so doesn't require the slow version of the loop, which has a + // bounds check on each iteration. A more sophisticated heuristic would tile + // along the x dimension with tile size kTargetNumElementsPerThread when either + // width is a multiple of (kWarpSize * kTargetNumElementsPerThread) or width is + // big enough that only a small fraction of the threads execute the slow + // version of the loop with bounds checks. 
+ if (width % (kWarpSize * kTargetNumElementsPerThread) != 0) { + x_tile_size = 8; + z_tile_size = 8; + while (depth % z_tile_size != 0) { + z_tile_size -= 1; + } + } + + return std::pair(x_tile_size, z_tile_size); +} + Status IrEmitterUnnested::EmitRowReduction( int64 depth, int64 height, int64 width, HloInstruction* reduce, - const Shape& input_shape, const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) { + const Shape& input_shape, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { // A naive algorithm is: - // 1. Divide the input tensor into tiles of size 1x1xK. + // 1. Divide the x dimension of the input tensor into tiles of size 1x1xX. // 2. Partially reduces each tile to a scalar using one thread. // 3. Accumulates that scalar to the output vector using atomic operations. // @@ -1282,15 +1438,15 @@ Status IrEmitterUnnested::EmitRowReduction( // int y = linear_index / width_in_tiles % height; // int z = linear_index / (height * width_in_tiles); // float partial_result = 0; - // for (element_id_in_tile : range(kTileSize)) { - // int x = x_in_tiles * kTileSize + element_id_in_tile; + // for (element_id_in_tile : range(x_tile_size)) { + // int x = x_in_tiles * x_tile_size + element_id_in_tile; // if (x < width) // partial_result = reducer(partial_result, input[z][y][z]); // } // AtomicReducer(&output[y], partial_result); // } // - // Three optimizations are performed. + // Four optimizations are performed. // // 1. To coalesce global memory accesses, dilate the tile with a factor of 32 // (i.e. the warp size). For example, suppose the width is 8x32=256. 
Instead @@ -1317,29 +1473,44 @@ Status IrEmitterUnnested::EmitRowReduction( // element_id_in_tile, which makes the code more friendly to optimizations // such as LICM. // + // 4. When the width is too small for x_tile_size to reach the target + // number of elements per thread, use a small factor of depth as + // z_tile_size to increase the number of elements accumulated by each + // partial sum. This reduces the number of dynamic shfl_down and atomic + // operations needed. + // // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; // linear_index < depth * height * width_in_tiles; // linear_index += blockDim.x * gridDim.x) { // int x_in_tiles = linear_index % width_in_tiles; // int y = linear_index / width_in_tiles % height; - // int z = linear_index / (height * width_in_tiles); + // int z_in_tiles = linear_index / (height * width_in_tiles); // int warp_id = x_in_tiles / warpSize; // int lane_id = x_in_tiles % warpSize; // float partial_result = 0; // int x = warp_id * kTileSize * warpSize + lane_id; - // if (width % (kTileSize * warpSize) == 0 || - // x + (kTileSize - 1) * warpSize < width) { - // // The entire tile is in bounds. - // for (int element_id_in_tile = 0; element_id_in_tile < kTileSize; - // ++element_id_in_tile, x += warpSize) { - // partial_result = Reducer(partial_result, input[z][y][x]); + // if (width % (x_tile_size * warpSize) == 0 || + // x + (x_tile_size - 1) * warpSize < width) { + // // The entire x_tile is in bounds. + // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size; + // ++element_id_in_z_tile) { + // z = z_in_tiles * z_tile_size + element_id_in_z_tile; + // for (int element_id_in_x_tile = 0; element_id_in_x_tile < x_tile_size; + // ++element_id_in_x_tile, x += warpSize) { + // partial_result = Reducer(partial_result, input[z][y][x]); + // } // } // } else { // // The tile is partially in bounds. 
- // for (int element_id_in_tile = 0; element_id_in_tile < kTileSize; + // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size; + // ++element_id_in_z_tile) { + // z = z_in_tiles * z_tile_size + element_id_in_z_tile; + // for (int element_id_in_x_tile = 0; element_id_in_x_tile < + // x_tile_size; // ++element_id_in_tile, x += warpSize) { - // if (x < width) - // partial_result = Reducer(partial_result, input[z][y][x]); + // if (x < width) + // partial_result = Reducer(partial_result, input[z][y][x]); + // } // } // } // for (shuffle_distance = 16; shuffle_distance > 0; shuffle_distance /= 2) @@ -1350,29 +1521,35 @@ Status IrEmitterUnnested::EmitRowReduction( // AtomicReducer(&output[y], partial_result); // } // - // Choose 8 as the tile size, which matches Eigen's RowReduceKernel. - constexpr int64 kTileSize = 8; + + int64 x_tile_size; + int64 z_tile_size; + std::tie(x_tile_size, z_tile_size) = + ComputeTilingSchemeForReduction(depth, width, kWarpSize); + // Round the width in tiles up to the nearest multiple of kWarpSize, so that // the use of shfl_down is valid. const int64 width_in_tiles = - RoundUpToNearest(CeilOfRatio(width, kTileSize), kWarpSize); + RoundUpToNearest(CeilOfRatio(width, x_tile_size), kWarpSize); - auto loop_body_emitter = - [=](const llvm_ir::IrArray::Index& tile_index) -> Status { - // Emit the loop body that reduces one tile. + auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) { + // Emit the loop body that reduces one z-x-tile. 
+ const int num_reduces = reducers.size(); llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( input_shape.element_type(), ir_emitter_context_->llvm_module()); - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); - { - TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value, - init_value_gen(llvm_ir::IrArray::Index({}))); + std::vector partial_reduction_result_addresses; + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index({}))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); } - // Emit an inner for-loop that partially reduces the elements in the given - // tile. - llvm::Value* z = tile_index[0]; + llvm::Value* z_tile = tile_index[0]; llvm::Value* y = tile_index[1]; llvm::Value* x_tile = tile_index[2]; llvm::Value* warp_id = ir_builder_.CreateUDiv( @@ -1380,102 +1557,132 @@ Status IrEmitterUnnested::EmitRowReduction( llvm::Value* lane_id = ir_builder_.CreateURem( x_tile, ir_builder_.getInt64(kWarpSize), "lane_id"); - // The x-location of the last element in this tile. - // last_x = lane_id + warpSize * (kTileSize - 1 + warp_id * kTileSize); + // The x-location of the last element in this z-x-tile. 
+ // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * + // x_tile_size); llvm::Value* last_x = ir_builder_.CreateNSWAdd( - lane_id, - ir_builder_.CreateNSWMul( - ir_builder_.getInt64(kWarpSize), - ir_builder_.CreateNSWAdd( - ir_builder_.getInt64(kTileSize - 1), - ir_builder_.CreateNSWMul(warp_id, - ir_builder_.getInt64(kTileSize))))); - - auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { - std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop("element_id_in_tile", - ir_builder_.getInt64(0), - ir_builder_.getInt64(kTileSize), - ir_builder_.getInt64(1), &ir_builder_); - - // Emit the body of the partial reduction loop. - llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), - &ir_builder_); - // x = lane_id + warpSize * (element_id_in_tile + warp_id * kTileSize); - llvm::Value* x = ir_builder_.CreateNSWAdd( - lane_id, - ir_builder_.CreateNSWMul( - ir_builder_.getInt64(kWarpSize), - ir_builder_.CreateNSWAdd( - tile_element_loop->GetIndVarValue(), - ir_builder_.CreateNSWMul(warp_id, - ir_builder_.getInt64(kTileSize))))); - - // Unless we know the tile is entirely in bounds, we have to emit a - // x-in-bounds check before reading from the input. - if (!tile_in_bounds) { - llvm_ir::LlvmIfData if_x_in_bounds_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(width)), - "x_in_bounds", &ir_builder_); - - // Points ir_builder_ to the then-block. - llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, - &ir_builder_); - } + lane_id, ir_builder_.CreateNSWMul( + ir_builder_.getInt64(kWarpSize), + ir_builder_.CreateNSWAdd( + ir_builder_.getInt64(x_tile_size - 1), + ir_builder_.CreateNSWMul( + warp_id, ir_builder_.getInt64(x_tile_size))))); + + KernelSupportLibrary ksl( + &ir_builder_, + /*unroll_mode=*/xla::llvm_ir::UnrollMode::kFullyUnroll, + /*prevent_vectorization=*/false); + + // Emit a for-loop that partially reduces the elements in the given + // z-x-tile. 
+ auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds, + int64 x_tile_loop_bound) -> Status { + auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status { + llvm::Value* z = ir_builder_.CreateNSWAdd( + z_indvar, ir_builder_.CreateNSWMul( + ir_builder_.getInt64(z_tile_size), z_tile)); + + TF_RETURN_IF_ERROR(ksl.For( + "x_tile", + /*start=*/0, /*end=*/x_tile_loop_bound, /*step=*/1, + [&](llvm::Value* x_indvar) -> Status { + // x = lane_id + warpSize * (element_id_in_x_tile + warp_id * + // x_tile_size); + llvm::Value* x = ir_builder_.CreateNSWAdd( + lane_id, + ir_builder_.CreateNSWMul( + ir_builder_.getInt64(kWarpSize), + ir_builder_.CreateNSWAdd( + x_indvar, + ir_builder_.CreateNSWMul( + warp_id, ir_builder_.getInt64(x_tile_size))))); + + // Unless we know the x-tile is entirely in bounds, we have to + // emit a x-in-bounds check before reading from the input. + if (!x_tile_in_bounds) { + llvm_ir::LlvmIfData if_x_in_bounds_data = + llvm_ir::EmitIfThenElse(ir_builder_.CreateICmpULT( + x, ir_builder_.getInt64(width)), + "x_in_bounds", &ir_builder_); + // Points ir_builder_ to the then-block. + llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, + &ir_builder_); + } + + // Emit code that reads the input element and accumulates it + // to the partial reduction result. + llvm::Value* input_address = + ir_builder_.CreateAlloca(element_ir_type); + { + // {z,y,x} is an index to input_3d_tensor_shape + // [depth,height,width]. We need to convert that to an index + // to input_shape (the shape of the operand of "reduce"). + // This conversion is composed of a transposition from + // input_shape to normalized_input_shape and a reshape from + // normalized_input_shape to input_3d_tensor_shape. 
+ const Shape normalized_input_shape = ShapeUtil:: + MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + input_shape); + auto input_shape_min2maj = + LayoutUtil::MinorToMajor(input_shape); + const std::vector transpose_dimension_mapping( + input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); + const Shape input_3d_tensor_shape = + ShapeUtil::MakeShapeWithDescendingLayout( + input_shape.element_type(), {depth, height, width}); + const llvm_ir::IrArray::Index input_3d_tensor_index( + {z, y, x}, input_3d_tensor_shape, &ir_builder_); + const llvm_ir::IrArray::Index input_index = + input_3d_tensor_index + .SourceIndexOfReshape(input_3d_tensor_shape, + normalized_input_shape, + &ir_builder_) + .SourceIndexOfTranspose( + normalized_input_shape, input_shape, + transpose_dimension_mapping, &ir_builder_); + + for (int i = 0; i != num_reduces; ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + input_gens[i](input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], input_address}, + partial_reduction_result_addresses[i])); + } + return EmitExtraOutputsForReduce(reduce, input_index, + extra_output_gens); + } + })); + return Status::OK(); + }; - // Emit code that reads the input element and accumulates it to the - // partial reduction result. - llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); - { - // {z,y,x} is an index to input_3d_tensor_shape [depth,height,width]. We - // need to convert that to an index to input_shape (the shape of the - // operand of "reduce"). This conversion is composed of a transposition - // from input_shape to normalized_input_shape and a reshape from - // normalized_input_shape to input_3d_tensor_shape. 
- const Shape normalized_input_shape = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - input_shape); - auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape); - const std::vector transpose_dimension_mapping( - input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); - const Shape input_3d_tensor_shape = - ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(), - {depth, height, width}); - const llvm_ir::IrArray::Index input_3d_tensor_index( - {z, y, x}, input_3d_tensor_shape, &ir_builder_); - const llvm_ir::IrArray::Index input_index = - input_3d_tensor_index - .SourceIndexOfReshape(input_3d_tensor_shape, - normalized_input_shape, &ir_builder_) - .SourceIndexOfTranspose(normalized_input_shape, input_shape, - transpose_dimension_mapping, - &ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, - input_gen(input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); - } - return EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, input_address}, - partial_reduction_result_address); + return ksl.For("z_tile", + /*start=*/0, /*end=*/z_tile_size, /*step=*/1, + emit_z_tile_element_loop); }; llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.getInt1(width % (kTileSize * kWarpSize) == 0), + ir_builder_.getInt1(width % (x_tile_size * kWarpSize) == 0), ir_builder_.CreateICmpULT(last_x, ir_builder_.getInt64(width))); - llvm_ir::LlvmIfData if_tile_in_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, - &ir_builder_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, - &ir_builder_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); - // After the if-then-else statement on tile_in_bounds, emit calls to - // shfl_down that accumulate the partial reduction 
results of all threads - // from the warp. - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, - &ir_builder_); + TF_RETURN_IF_ERROR( + ksl.If(tile_in_bounds, + /*true_block_generator=*/ + [&]() -> Status { + return emit_z_x_tile_element_loop(/*x_tile_in_bounds=*/true, + x_tile_size); + }, + /*false_block_generator=*/ + [&]() -> Status { + return emit_z_x_tile_element_loop( + /*x_tile_in_bounds=*/false, + CeilOfRatio(width % (x_tile_size * kWarpSize), kWarpSize)); + })); + + // After accumulating the elements of the z_x_tile, emit calls to + // shfl_down that accumulate the partial reduction results of all + // threads in a warp. int bit_width = llvm_ir::GetSizeInBits(element_ir_type); // bitcast cannot be applied to aggregate types (even packed ones), so we // instead bitcast addresses of load/store to intN* of the same bit-width. @@ -1484,20 +1691,24 @@ Status IrEmitterUnnested::EmitRowReduction( : element_ir_type; for (int shuffle_distance = 16; shuffle_distance >= 1; shuffle_distance /= 2) { - llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(partial_reduction_result_address, - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( element_ir_type, nullptr, "result_from_other_lane"); - ir_builder_.CreateStore( - EmitShuffleDown(partial_reduction_result, - ir_builder_.getInt32(shuffle_distance), &ir_builder_), - ir_builder_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, result_from_other_lane}, - partial_reduction_result_address)); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( + ir_builder_.CreateBitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); + ir_builder_.CreateStore( + 
EmitShuffleDown(partial_reduction_result, + ir_builder_.getInt32(shuffle_distance), + &ir_builder_), + ir_builder_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], result_from_other_lane}, + partial_reduction_result_addresses[i])); + } } const HloInstruction* output = @@ -1511,19 +1722,34 @@ Status IrEmitterUnnested::EmitRowReduction( "lane_id_is_zero", &ir_builder_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &ir_builder_); - llvm::Value* output_address = - GetIrArray(*output, *output) - .EmitArrayElementAddress( - llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_), - &ir_builder_, "output_element_address"); - return EmitAtomicOperationForNestedComputation( - *reducer, output_address, partial_reduction_result_address); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* output_address = + GetIrArray(*output, *output, reduce_output_shapes[i]) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index( + y, + ShapeUtil::GetSubshape(output->shape(), + reduce_output_shapes[i]), + &ir_builder_), + &ir_builder_, "output_element_address"); + if (x_tile_size * z_tile_size < depth * width) { + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, + partial_reduction_result_addresses[i])); + } else { + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {output_address, partial_reduction_result_addresses[i]}, + output_address)); + } + } + return Status::OK(); }; // Emit a parallel loop that iterates through every input tiles. 
Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), {depth, height, width_in_tiles}, - {2, 1, 0}); + reduce->shape().element_type(), + {depth / z_tile_size, height, width_in_tiles}, {2, 1, 0}); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( tiled_input_shape, ir_emitter_context_->device_description()); CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); @@ -1544,10 +1770,14 @@ Status IrEmitterUnnested::EmitRowReduction( // elementwise. Status IrEmitterUnnested::EmitReductionToVector( HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, tensorflow::gtl::ArraySlice dimensions_to_reduce, - HloComputation* reducer) { + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { // This emission requires "reduce" to have an input layout. It is either set // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for // a fused kReduce). @@ -1582,8 +1812,9 @@ Status IrEmitterUnnested::EmitReductionToVector( // `EmitReductionToVector`, we only need to check whether the minormost // dimension of the input is to keep. if (input_dims_to_keep.empty()) { - return EmitReductionToScalar(reduce, input_shape, input_gen, init_value_gen, - reducer); + return EmitReductionToScalar(reduce, input_shape, input_gens, + init_value_gens, reducers, + reduce_output_shapes, extra_output_gens); } else if (input_dims_to_keep.front() == LayoutUtil::Minor(input_shape.layout(), 0)) { // Column reduction. 
Treat the result of "input" as a matrix whose width @@ -1600,8 +1831,9 @@ Status IrEmitterUnnested::EmitReductionToVector( height *= input_shape.dimensions(input_dim); } } - return EmitColumnReduction(height, width, reduce, input_shape, input_gen, - init_value_gen, reducer); + return EmitColumnReduction(height, width, reduce, input_shape, input_gens, + init_value_gens, reducers, reduce_output_shapes, + extra_output_gens); } else { // Reduce the row dimension of a matrix or reduce dimension 0 and 2 in a // 3D tensor. The size of dimension 1 (the height) is the size of the @@ -1627,7 +1859,8 @@ Status IrEmitterUnnested::EmitReductionToVector( } const int64 height = ShapeUtil::ElementsIn(reduce->shape()); return EmitRowReduction(depth, height, width, reduce, input_shape, - input_gen, init_value_gen, reducer); + input_gens, init_value_gens, reducers, + reduce_output_shapes, extra_output_gens); } } @@ -1651,16 +1884,15 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { MakeUnique(std::move(thunks), reduce)); return EmitReductionToVector( - reduce, input->shape(), - [&](const llvm_ir::IrArray::Index& index) { + reduce, input->shape(), {[&](const llvm_ir::IrArray::Index& index) { return GetIrArray(*input, *reduce) .EmitReadArrayElement(index, &ir_builder_); - }, - [&](const llvm_ir::IrArray::Index& index) { + }}, + {[&](const llvm_ir::IrArray::Index& index) { return GetIrArray(*init_value, *reduce) .EmitReadArrayElement(index, &ir_builder_); - }, - dimensions_to_reduce, reducer); + }}, + dimensions_to_reduce, {reducer}, {{}}, {}); } thunk_sequence_->emplace_back(BuildKernelThunk(reduce)); @@ -1928,6 +2160,52 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { return IrEmitter::HandleSelect(select); } +Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { + if (hlo_module_config_.replica_count() != 1) { + // TODO(b/33011107): Support nontrivial cross replica sum on GPU. 
+ return Unimplemented( + "CrossReplicaSum with >1 replica is not implemented on GPU."); + } + + // CRS with one operand and one replica is simply the identity function. + // Buffer assignment expects a copy, so that's what we do. + // + // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely + // in algebraic-simplifier, but currently on some platforms + // HloModuleConfig::num_replicas changes between when the module is compiled + // and when it's run. + if (crs->operand_count() == 1) { + CHECK(ShapeUtil::IsArray(crs->operand(0)->shape())) + << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); + thunk_sequence_->push_back(MakeUnique( + /*source_address=*/GetAllocationSlice(*crs->operand(0)), + /*destination_buffer=*/GetAllocationSlice(*crs), + /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs)); + return Status::OK(); + } + + // One-replica CRS with multiple operands produces a tuple of the inputs. + // Again, buffer assignment expects us to copy each. + std::vector> thunks; + std::vector tuple_element_buffers; + for (int64 i = 0; i < crs->operand_count(); ++i) { + tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment() + .GetUniqueSlice(crs, {i}) + .ValueOrDie()); + thunks.push_back(MakeUnique( + /*source_address=*/GetAllocationSlice(*crs->operand(i)), + /*destination_buffer=*/tuple_element_buffers.back(), + /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), crs)); + } + + // Output a tuple of the buffers above. 
+ thunks.push_back(MakeUnique(tuple_element_buffers, + GetAllocationSlice(*crs), crs)); + thunk_sequence_->push_back( + MakeUnique(std::move(thunks), crs)); + return Status::OK(); +} + Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) { thunk_sequence_->emplace_back(BuildInfeedThunk(infeed)); return Status::OK(); @@ -2194,6 +2472,21 @@ std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( /*destination_buffer=*/GetAllocationSlice(*inst), inst); } +namespace { +double GetScalarConstantAsDouble(const Literal& literal) { + switch (literal.shape().element_type()) { + case F16: + return static_cast(literal.Get({})); + case F32: + return literal.Get({}); + case F64: + return literal.Get({}); + default: + LOG(FATAL) << "Unsupported type."; + } +} +} // namespace + std::unique_ptr IrEmitterUnnested::BuildGemmThunk( const HloInstruction* inst) { if (inst->opcode() == HloOpcode::kDot) { @@ -2218,6 +2511,19 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( if (dot->opcode() != HloOpcode::kDot) { std::swap(dot, alpha); } + if (alpha->opcode() == HloOpcode::kBroadcast) { + alpha = alpha->operand(0); + } + if (alpha->opcode() == HloOpcode::kParameter) { + alpha = inst->operand(alpha->parameter_number()); + } + // TODO(b/74185543): Remove the following if block once we support fusion + // with a non-constant as well. Then we will just always use the constant + // on the device. + if (alpha->opcode() == HloOpcode::kCopy) { + alpha = alpha->operand(0); + } + DCHECK(dot->opcode() == HloOpcode::kDot); const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); @@ -2229,13 +2535,13 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( inst->operand(rhs_parameter->parameter_number()); return MakeUnique( - GetAllocationSlice(*lhs), // The buffer assigned to LHS. - GetAllocationSlice(*rhs), // The buffer assigned to RHS. - GetAllocationSlice(*mul), // The output buffer. 
- lhs->shape(), // The shape of LHS. - rhs->shape(), // The shape of RHS. - inst->shape(), // The shape of the output. - alpha->literal().Get({0}), // alpha. + GetAllocationSlice(*lhs), // The buffer assigned to LHS. + GetAllocationSlice(*rhs), // The buffer assigned to RHS. + GetAllocationSlice(*inst), // The output buffer. + lhs->shape(), // The shape of LHS. + rhs->shape(), // The shape of RHS. + inst->shape(), // The shape of the output. + GetScalarConstantAsDouble(alpha->literal()), // alpha. inst); } @@ -2253,21 +2559,30 @@ std::unique_ptr IrEmitterUnnested::BuildFftThunk( } StatusOr> IrEmitterUnnested::BuildInitializerThunk( - const HloInstruction* hlo) { + const HloInstruction* hlo, const ShapeIndex& index) { bool fused = HloOpcode::kFusion == hlo->opcode(); const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; - const HloInstruction* init_value = [&] { + const HloInstruction* init_value_operand = [&] { switch (inst->opcode()) { case HloOpcode::kSelectAndScatter: return inst->operand(2); case HloOpcode::kReduce: return inst->operand(1); + case HloOpcode::kTuple: + CHECK(hlo->IsMultiOutputFusion()) + << ": " << hlo->ToString() << " is not a multi-output fusion."; + CHECK(inst->operand(index.back())->opcode() == HloOpcode::kReduce) + << ": Found '" << inst->operand(index.back())->opcode() << "' in " + << inst->ToString() << " but expected 'reduce'."; + // For multi-output fusion look through the tuple. 
+ return inst->operand(index.back())->operand(1); default: LOG(FATAL) << "Opcode " << inst->opcode() << " should not need an initializer."; } }(); + const HloInstruction* init_value = init_value_operand; if (fused && init_value->opcode() == HloOpcode::kParameter) { init_value = hlo->operand(init_value->parameter_number()); } @@ -2285,7 +2600,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( ArraySlice literal_bytes( reinterpret_cast(literal.untyped_data()), num_bytes); if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { - return {MakeUnique(GetAllocationSlice(*hlo), hlo)}; + return {MakeUnique(GetAllocationSlice(*hlo, index), hlo)}; } // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by @@ -2301,8 +2616,8 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( pattern16 = literal_bytes.front(); } uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); - return {MakeUnique(pattern32, - GetAllocationSlice(*hlo), hlo)}; + return {MakeUnique( + pattern32, GetAllocationSlice(*hlo, index), hlo)}; } // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit @@ -2312,20 +2627,31 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( literal_bytes.size() - 4) == 0) { uint32 word; memcpy(&word, literal_bytes.data(), sizeof(word)); - return {MakeUnique(word, GetAllocationSlice(*hlo), - hlo)}; + return {MakeUnique( + word, GetAllocationSlice(*hlo, index), hlo)}; } } // Otherwise fall back to our slow initializer code. 
std::unique_ptr kernel_thunk = BuildKernelThunk(hlo); - TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk( - *hlo, - [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*init_value, *hlo) - .EmitReadArrayElement(index, &ir_builder_); - }, - kernel_thunk.get())); + LaunchDimensions launch_dimensions = + CalculateLaunchDimensions(ShapeUtil::GetSubshape(hlo->shape(), index), + ir_emitter_context_->device_description()); + UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), + ir_emitter_context_->llvm_module()); + // If the init_value was fused into this reduce we have to generate it first. + if (fused && init_value_operand->opcode() != HloOpcode::kParameter) { + CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode()); + TF_RETURN_IF_ERROR(HandleConstant(const_cast(init_value))); + } + TF_RETURN_IF_ERROR(ParallelLoopEmitter( + [=](const llvm_ir::IrArray::Index& index) { + return GetIrArray(*init_value, *hlo) + .EmitReadArrayElement(index, &ir_builder_); + }, + GetIrArray(*hlo, *hlo, index), launch_dimensions, + &ir_builder_) + .EmitLoop(IrName(hlo))); // Clean up state left behind by emitting the loop above. (This is normally // done in IrEmitterUnnested::Postprocess().) @@ -2512,16 +2838,14 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( .EmitLoop(IrName(&hlo)); } - CHECK_EQ(unroll_factor, 1) - << "multi-output fusion does not support unrolling"; - // For multiple outputs fusion, we need to emit each operand and the root. 
std::vector output_arrays; for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { output_arrays.push_back(GetIrArray(hlo, hlo, {i})); } TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays, - launch_dimensions, &ir_builder_) + launch_dimensions, &ir_builder_, + unroll_factor) .EmitLoop(IrName(&hlo))); std::vector tuple_operand_ptrs; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index b41ab2162ab81f66e123a7055ca3ffc815c3ef88..202231b82f3877c11cf932bd00a8aac350fd0afa 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -76,6 +76,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleInfeed(HloInstruction* xla_infeed) override; Status HandleRng(HloInstruction* random) override; Status HandleSelect(HloInstruction* select) override; + Status HandleCrossReplicaSum(HloInstruction* crs) override; Status EmitTargetElementLoop( const HloInstruction& hlo, @@ -99,6 +100,13 @@ class IrEmitterUnnested : public IrEmitter { const HloInstruction& inst, tensorflow::gtl::ArraySlice args); + // Helper for writing extra outputs from inside a reduce kernel. + Status EmitExtraOutputsForReduce( + const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); + // EmitColumnReduction and EmitRowReduction emit code for column and row // reduction of a matrix and/or 3D tensor. Row and column reduction have // different memory access pattern, so for performance their implementations @@ -109,28 +117,43 @@ class IrEmitterUnnested : public IrEmitter { // `EmitReductionToVector`. Note that input shape might not be // [height x width], but can be bitcast to [height x width] with "height" // being the major dimension.
- Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce, - const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); + Status EmitColumnReduction( + int64 height, int64 width, HloInstruction* reduce, + const Shape& input_shape, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); // Emits code that reduces a 3D tensor of shape [depth x height x width] to a // vector of shape [height]. Other parameters have the same meaning as those // of `EmitReductionToVector`. Note that input shape might not be // [depth x height x width], but can be bitcast to [depth x height x width] // with "depth" being the most major dimension. - Status EmitRowReduction(int64 depth, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); + Status EmitRowReduction( + int64 depth, int64 height, int64 width, HloInstruction* reduce, + const Shape& input_shape, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); // Emits code that reduces a tensor of arbitrary rank to a scalar.
- Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); + Status EmitReductionToScalar( + HloInstruction* reduce, const Shape& input_shape, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); // Figures out whether `reduce` is a row or column reduction, and which // dimensions to reduce, and calls either `EmitRowReduction` or @@ -140,13 +163,24 @@ class IrEmitterUnnested : public IrEmitter { // generate elements of the input and the initial value. Other parameters mean // the same as for `HandleReduce`. // + // Multiple reduces can be emitted in the same loop, assuming they have the + // same input and output shapes, and the same reduce dimensions. + // + // extra_output_gens can contain extra generators for intermediate outputs. + // These must have the same shape as the reduce input as they are computed + // when the reduce inputs are being read. + // // Prerequisite: `IsReductionToVector(*reduce)` Status EmitReductionToVector( HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, tensorflow::gtl::ArraySlice dimensions_to_reduce, - HloComputation* reducer); + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); // Returns a KernelThunk that invokes the kernel emitted for `inst`. 
The // caller needs to make sure `inst` outlives the lifetime of the returned @@ -165,7 +199,7 @@ class IrEmitterUnnested : public IrEmitter { // Returns a thunk that, given a reduce or select-and-scatter op, initializes // its memory to the appropriate initial value. StatusOr> BuildInitializerThunk( - const HloInstruction* hlo); + const HloInstruction* hlo, const ShapeIndex& index = {}); // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. std::unique_ptr BuildHostToDeviceCopyThunk(const HloInstruction* inst); diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index d376ef7a245eb9ed86939f44c611b6dde5606b23..f56c1ce69f11ed79c8be76834269f29de93a9645 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -35,26 +35,38 @@ KernelThunk::KernelThunk( kernel_name_(kernel_name), unroll_factor_(unroll_factor) {} -tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable) { +Status KernelThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { tensorflow::mutex_lock lock(mutex_); - if (loader_spec_) { - // Already initialized by another thread. - return tensorflow::Status::OK(); - } + if (!loader_spec_) { + loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); + tensorflow::StringPiece ptx = executable.ptx(); + // Convert tensorflow::StringPiece to se::port::StringPiece because + // StreamExecutor uses the latter. + loader_spec_->AddCudaPtxInMemory( + se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_); - loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); - tensorflow::StringPiece ptx = executable.ptx(); - // Convert tensorflow::StringPiece to se::port::StringPiece because - // StreamExecutor uses the latter. 
- loader_spec_->AddCudaPtxInMemory( - se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_); + if (!executable.cubin().empty()) { + loader_spec_->AddCudaCubinInMemory( + reinterpret_cast(executable.cubin().data()), + kernel_name_); + } + } - if (!executable.cubin().empty()) { - loader_spec_->AddCudaCubinInMemory( - reinterpret_cast(executable.cubin().data()), kernel_name_); + // Load the kernel into the device if necessary. + // + // We could alternatively do this within ExecuteOnStream, but doing it here + // lets the time spent loading the kernel not count towards our execution + // profiles. + auto it = kernel_cache_.find(executor); + if (kernel_cache_.end() == it) { + it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first; + if (!executor->GetKernel(*loader_spec_, &it->second)) { + return InternalError("Unable to load kernel %s", kernel_name_.c_str()); + } } - return tensorflow::Status::OK(); + return Status::OK(); } void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) { @@ -62,21 +74,18 @@ void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) { launch_dimensions_ = launch_dims; } -tensorflow::Status KernelThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { // Load the kernel. 
se::StreamExecutor* executor = stream->parent(); LaunchDimensions launch_dimensions; const se::KernelBase* kernel = nullptr; + { tensorflow::mutex_lock lock(mutex_); auto it = kernel_cache_.find(executor); - if (kernel_cache_.end() == it) { - it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first; - if (!executor->GetKernel(*loader_spec_, &it->second)) { - return InternalError("Unable to load kernel %s", kernel_name_.c_str()); - } - } + CHECK(it != kernel_cache_.end()) + << "Initialize() not called for StreamExecutor " << executor; launch_dimensions = launch_dimensions_; kernel = &it->second; } @@ -97,7 +106,7 @@ tensorflow::Status KernelThunk::ExecuteOnStream( *kernel_args)) { return InternalError("Unable to launch kernel %s", kernel_name_.c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index b556befe66b6bebba1a958f553f0a9b2c4eebbe4..7def27e189b66747569344a3dbe5c0c446f903be 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -57,11 +57,12 @@ class KernelThunk : public Thunk { int unroll_factor() const { return unroll_factor_; } void SetLaunchDimensions(const LaunchDimensions& launch_dims); - tensorflow::Status Initialize(const GpuExecutable& executable) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; // Executes the kernel for the thunk on "stream", which must be non-null. - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: // Buffers passed to the kernel as arguments. 
@@ -83,7 +84,8 @@ class KernelThunk : public Thunk { mutable tensorflow::mutex mutex_; std::unique_ptr loader_spec_ GUARDED_BY(mutex_); - // Loaded kernels for each `StreamExecutor` + // Loaded kernels for each `StreamExecutor`. Requires pointer stability of + // values. std::unordered_map kernel_cache_ GUARDED_BY(mutex_); }; diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index 86c4ac18b0501c38aaaae5a007bddcf261ca338f..7de8f9e1ee922bdbf65fd1299702482e1843f17e 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -47,7 +47,6 @@ cc_library( "@llvm//:scalar", "@llvm//:support", "@llvm//:target", - "@llvm//:transform_utils", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index d70cb07c57d48c0faed2cdc5ea9fc5ce5fb32be0..a4e4e85bf3d2c197cfc691b7fca0920aa6571729 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -77,8 +77,7 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, // Since CUDA 9.0, all GPU versions are included in a single file const char* unified_libdevice_filename = "libdevice.10.bc"; std::vector unified_libdevice_files; - const tensorflow::Status status = - tensorflow::Env::Default()->GetMatchingPaths( + const Status status = tensorflow::Env::Default()->GetMatchingPaths( tensorflow::io::JoinPath(libdevice_dir_path, unified_libdevice_filename), &unified_libdevice_files); if (status.ok() && unified_libdevice_files.size() == 1) { @@ -273,7 +272,7 @@ string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) { codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( llvm::Triple(module->getTargetTriple()))); - 
target_machine->addPassesToEmitFile(codegen_passes, pstream, + target_machine->addPassesToEmitFile(codegen_passes, pstream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile); codegen_passes.run(*module); } @@ -311,11 +310,11 @@ bool CouldNeedLibdevice(const llvm::Module& module) { } // Links libdevice into the given module if the module needs libdevice. -tensorflow::Status LinkLibdeviceIfNecessary( - llvm::Module* module, std::pair compute_capability, - const string& libdevice_dir_path) { +Status LinkLibdeviceIfNecessary(llvm::Module* module, + std::pair compute_capability, + const string& libdevice_dir_path) { if (!CouldNeedLibdevice(*module)) { - return tensorflow::Status::OK(); + return Status::OK(); } llvm::Linker linker(*module); @@ -336,7 +335,7 @@ tensorflow::Status LinkLibdeviceIfNecessary( return tensorflow::errors::Internal(tensorflow::strings::StrCat( "Error linking libdevice from ", libdevice_path)); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr CompileModuleToPtx(llvm::Module* module, diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc new file mode 100644 index 0000000000000000000000000000000000000000..942c25453371c49a35c9f5148b6028e6a87d688d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -0,0 +1,122 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace gpu { + +GpuMultiOutputFusion::GpuMultiOutputFusion() : MultiOutputFusion(INT64_MAX) {} + +bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, + HloInstruction* instr2) { + auto get_element_shape = [&](HloInstruction* instr) { + const HloInstruction* element_instr = instr; + if (instr->opcode() == HloOpcode::kFusion) { + auto fused_expression_root = instr->fused_expression_root(); + if (instr->IsMultiOutputFusion()) { + // The shapes in all tuple operands should agree. Just pick the first + // one. + element_instr = fused_expression_root->operands()[0]; + } else { + element_instr = fused_expression_root; + } + } + // Special handling of kReduce instructions -- the fusion + // applies to the first operand. + if (element_instr->opcode() == HloOpcode::kReduce) { + return element_instr->operand(0)->shape(); + } + return element_instr->shape(); + }; + + // The elementwise output shapes must be the same (including layout) + return ShapeUtil::Equal(get_element_shape(instr1), get_element_shape(instr2)); +} + +bool GpuMultiOutputFusion::IsProfitableOperand(HloInstruction* instr) { + // kConstant instruction will not have memory reads, so it won't be a profit + // source. Skip them. + if (instr->opcode() == HloOpcode::kConstant && + ShapeUtil::IsEffectiveScalar(instr->shape())) { + return false; + } + // We don't target to fuse producer/consumer instructions -- this should + // be taken care of by the instruction_fusion pass. 
If instr has only + // one user, it will not have sibling instructions. We won't consider it. + if (instr->user_count() < 2) { + return false; + } + return true; +} + +namespace { +bool IsReduction(HloInstruction* instr) { + if (instr->IsMultiOutputFusion()) { + for (const HloInstruction* operand : + instr->fused_expression_root()->operands()) { + if (operand->opcode() == HloOpcode::kReduce) { + return true; + } + } + return false; + } else if (instr->opcode() == HloOpcode::kFusion) { + return instr->fused_expression_root()->opcode() == HloOpcode::kReduce; + } else { + return instr->opcode() == HloOpcode::kReduce; + } +} +} // namespace + +bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { + return IsReduction(instr); +} + +int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, + HloInstruction* instr2) { + tensorflow::gtl::FlatSet in_list; + for (auto instr : instr1->operands()) { + if (!IsProfitableOperand(instr)) { + continue; + } + in_list.insert(instr); + } + int64 profit = 0; + for (auto instr : instr2->operands()) { + if (!IsProfitableOperand(instr) || in_list.count(instr) == 0) { + continue; + } + profit += ShapeUtil::ByteSizeOf(instr->shape()); + } + VLOG(2) << "Fusing instr1=" << instr1->name() << " instr2=" << instr2->name() + << ", the profit is =" << profit; + return profit; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h new file mode 100644 index 0000000000000000000000000000000000000000..5451a93cec5e4aeca05717c181edb9dad0305c83 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ + +#include "tensorflow/compiler/xla/service/multi_output_fusion.h" + +namespace xla { +namespace gpu { + +// Multi-output fusion of sibling and producer-consumer instructions for the +// GPU backend. +class GpuMultiOutputFusion : public MultiOutputFusion { + public: + GpuMultiOutputFusion(); + + protected: + // Test if instr1 and instr2 have compatible shapes that can be legally + // fused. + bool ShapesCompatibleForFusion(HloInstruction* instr1, + HloInstruction* instr2) override; + + // We currently only consider reduce and reduce fusion nodes as candidates. + bool IsFusible(HloInstruction* instr) override; + + // This function estimates the amount of memory reads saved by merging + // instr1 and instr2 into one multi-output fusion instruction. For a fusion + // instruction, all the operands need to be loaded from memory. If we merge + // instr1 and instr2, common operands will not be loaded twice. The profit is + // estimated as the size of the common operands between instr1 and instr2. + int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) override; + + // Whether fusing the instruction can reduce memory reads. + // + // TODO(tjoerg): Move this method up into the MultiOutputFusion base class.
+ bool IsProfitableOperand(HloInstruction* instr) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5170cbc7e3755c2d5b57b3dad4ed0a12556d65af --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -0,0 +1,171 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace gpu { + +using InstructionFusionTest = HloTestBase; + +const char kModulePrefix[] = R"( + HloModule test_module + + scalar_add_computation { + scalar_lhs = f32[] parameter(0) + scalar_rhs = f32[] parameter(1) + ROOT add = f32[] add(scalar_lhs, scalar_rhs) + } + scalar_mul_computation { + scalar_lhs = f32[] parameter(0) + scalar_rhs = f32[] parameter(1) + ROOT mul = f32[] multiply(scalar_lhs, scalar_rhs) + })"; + +TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { + // Fusion with reduce instruction root and a sibling reduce instruction + // sharing the same input param.
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation { + p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + const.2 = f32[] constant(1) + fusion = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation + reduce.2 = f32[512]{0} reduce(p1, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation + ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion, reduce.2) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce())); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceInputShapes) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[6400]{0} parameter(1) + mul = f32[6400]{0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[] reduce(p1.1, const.1), dimensions={0}, to_apply=scalar_add_computation + } + + fused_computation_2 { + p1.2 = f32[6400]{0} parameter(1) + r1 = f32[64,100]{0,1} reshape(p1.2) + const.2 = f32[] parameter(0) + ROOT reduce.2 = f32[] reduce(r1, const.2), dimensions={1,0}, to_apply=scalar_mul_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[6400]{0} parameter(1) + const.2 = f32[] constant(1) + fusion.1 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_1 + fusion.2 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[], f32[]) tuple(fusion.1, 
fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) { + // Two sibling fusions with reduce instruction roots sharing the same input + // param. + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + } + + fused_computation_2 { + p1.2 = f32[128,512,28,28]{3,2,1,0} parameter(1) + const.2 = f32[] parameter(0) + ROOT reduce.2 = f32[512]{0} reduce(p1.2, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + fusion.1 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_1 + fusion.2 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce())); +} + +TEST_F(InstructionFusionTest, + MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) { + // Multi-output fusion with two reduce instructions root and a sibling reduce + // instruction sharing the same input param. 
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) { + const.1 = f32[] constant(1) + p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0) + mul = f32[128,512,28,28]{3,2,1,0} multiply(f32[128,512,28,28]{3,2,1,0} p0.1, f32[128,512,28,28]{3,2,1,0} p0.1) + reduce.1 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} mul, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + reduce.2 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} p0.1, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + ROOT tuple = (f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} reduce.1, f32[512]{0} reduce.2) + } + + ENTRY entry (p0: f32[128,512,28,28]) -> (f32[512], f32[512], f32[512]) { + p0 = f32[128,512,28,28]{3,2,1,0} parameter(0) + const = f32[] constant(1) + fusion = (f32[512]{0}, f32[512]{0}) fusion(f32[128,512,28,28]{3,2,1,0} p0), kind=kInput, calls=fused_computation + get-tuple-element = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=0 + get-tuple-element.1 = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=1 + reduce.3 = f32[512]{0} reduce(p0, const), dimensions={0,2,3}, to_apply=scalar_add_computation + ROOT root = (f32[512]{0}, f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} get-tuple-element, f32[512]{0} get-tuple-element.1, f32[512]{0} reduce.3) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce(), op::Reduce())); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 
7bda4e2fcd469bd430e5ef1846251c8504225383..c8f0d4185c63c5bafca6f30acab31cbe8e987277 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -370,26 +370,38 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( return true; } -StatusOr PadInsertion::Run(HloModule* module) { +StatusOr PadInsertion::RunOnComputation(HloComputation* computation) { bool changed = false; - for (HloInstruction* instruction : - module->entry_computation()->MakeInstructionPostOrder()) { - if (IsCustomCallToDnnConvolution(*instruction)) { - const auto& target = instruction->custom_call_target(); - if (target == kCudnnConvForwardCallTarget) { - changed |= CanonicalizeForwardConvolution(instruction); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - changed |= CanonicalizeBackwardFilterConvolution(instruction); - } else if (target == kCudnnConvBackwardInputCallTarget) { - changed |= CanonicalizeBackwardInputConvolution(instruction); - } else { - LOG(FATAL) << "Unknown custom call target for cudnn conv: " - << instruction->ToString(); - } + std::vector convs; + for (auto* instr : computation->instructions()) { + if (IsCustomCallToDnnConvolution(*instr)) { + convs.push_back(instr); + } + } + for (HloInstruction* instruction : convs) { + const auto& target = instruction->custom_call_target(); + if (target == kCudnnConvForwardCallTarget) { + changed |= CanonicalizeForwardConvolution(instruction); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + changed |= CanonicalizeBackwardFilterConvolution(instruction); + } else if (target == kCudnnConvBackwardInputCallTarget) { + changed |= CanonicalizeBackwardInputConvolution(instruction); + } else { + LOG(FATAL) << "Unknown custom call target for cudnn conv: " + << instruction->ToString(); } } return changed; } +StatusOr PadInsertion::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) 
{ + TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); + changed |= result; + } + return changed; +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h index 5e1c68701daa02eba64f3e34933ce373a496c1b8..67e51509e4c717951c83c7e41943af1de762dee0 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h @@ -31,6 +31,7 @@ class PadInsertion : public HloPassInterface { StatusOr Run(HloModule* module) override; private: + StatusOr RunOnComputation(HloComputation* computation); // Returns if any changes are made to the parent computation. bool CanonicalizeForwardConvolution(HloInstruction* conv); bool CanonicalizeBackwardFilterConvolution(HloInstruction* backward_conv); diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc index c8510808f10a731af90154447bd3e1e037db6348..88cb10883e97ae663dc492ad088e6daf9133d7f5 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc @@ -20,24 +20,24 @@ limitations under the License. 
namespace xla { namespace gpu { -SequentialThunk::SequentialThunk(std::vector<std::unique_ptr<Thunk>>&& thunks, +SequentialThunk::SequentialThunk(std::vector<std::unique_ptr<Thunk>> thunks, const HloInstruction* hlo) : Thunk(Kind::kSequential, hlo), thunks_(std::move(thunks)) {} -tensorflow::Status SequentialThunk::Initialize( - const GpuExecutable& executable) { +Status SequentialThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { for (auto& thunk : thunks_) { - TF_RETURN_IF_ERROR(thunk->Initialize(executable)); + TF_RETURN_IF_ERROR(thunk->Initialize(executable, executor)); } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status SequentialThunk::ExecuteOnStream( +Status SequentialThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { for (const auto& thunk : thunks_) { TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h index df17b8d67b80321c7088243eae46e7a723b4ede9..135f79e413dfaa27f2f2264e0daa3beb3c305e0f 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h @@ -31,16 +31,17 @@ namespace gpu { // require multiple kernel launches or library calls.
class SequentialThunk : public Thunk { public: - SequentialThunk(std::vector<std::unique_ptr<Thunk>>&& thunks, + SequentialThunk(std::vector<std::unique_ptr<Thunk>> thunks, const HloInstruction* hlo); SequentialThunk(const SequentialThunk&) = delete; SequentialThunk& operator=(const SequentialThunk&) = delete; const std::vector<std::unique_ptr<Thunk>>& thunks() const { return thunks_; } - tensorflow::Status Initialize(const GpuExecutable& executable) override; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: // The list of sub-thunks. diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index b42767dfd500bd87ad5bd88c3f39072058b18673..6f4bb0580e8dfc1dce1cca0a60cc3dd9ea600fb3 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -28,6 +28,14 @@ namespace gpu { class StreamAssignmentTest : public HloTestBase { protected: + std::unique_ptr<HloModule> CreateNewModule() { + HloModuleConfig config; + auto debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_disable_multi_streaming(false); + config.set_debug_options(debug_options); + return MakeUnique<HloModule>("test_module", config); + } + // Pre-canned shapes. Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2}); }; diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..a50ddf6ac63c7fa7ccace94bc7f40f438aedccf8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -0,0 +1,151 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" + +#include "tensorflow/compiler/xla/layout_util.h" + +namespace xla { +namespace gpu { + +using stream_executor::dnn::DataLayout; +using stream_executor::dnn::DataLayoutString; +using stream_executor::dnn::FilterLayout; +using stream_executor::dnn::FilterLayoutString; + +StatusOr> +StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, + DataLayout input, FilterLayout filter, + DataLayout output) { + std::vector input_layout; + switch (input) { + case DataLayout::kBatchDepthYX: + input_layout.push_back(dnums.input_batch_dimension()); + input_layout.push_back(dnums.input_feature_dimension()); + input_layout.insert(input_layout.end(), + dnums.input_spatial_dimensions().begin(), + dnums.input_spatial_dimensions().end()); + break; + case DataLayout::kBatchYXDepth: + input_layout.push_back(dnums.input_batch_dimension()); + input_layout.insert(input_layout.end(), + dnums.input_spatial_dimensions().begin(), + dnums.input_spatial_dimensions().end()); + input_layout.push_back(dnums.input_feature_dimension()); + break; + default: + return tensorflow::errors::Internal("Invalid input layout: ", + DataLayoutString(input)); + } + + std::vector filter_layout; + switch (filter) { + case FilterLayout::kOutputInputYX: + 
filter_layout.push_back(dnums.kernel_output_feature_dimension()); + filter_layout.push_back(dnums.kernel_input_feature_dimension()); + filter_layout.insert(filter_layout.end(), + dnums.kernel_spatial_dimensions().begin(), + dnums.kernel_spatial_dimensions().end()); + break; + case FilterLayout::kOutputYXInput: + filter_layout.push_back(dnums.kernel_output_feature_dimension()); + filter_layout.insert(filter_layout.end(), + dnums.kernel_spatial_dimensions().begin(), + dnums.kernel_spatial_dimensions().end()); + filter_layout.push_back(dnums.kernel_input_feature_dimension()); + break; + default: + return tensorflow::errors::Internal("Invalid filter layout: ", + FilterLayoutString(filter)); + } + + std::vector output_layout; + switch (output) { + case DataLayout::kBatchDepthYX: + output_layout.push_back(dnums.output_batch_dimension()); + output_layout.push_back(dnums.output_feature_dimension()); + output_layout.insert(output_layout.end(), + dnums.output_spatial_dimensions().begin(), + dnums.output_spatial_dimensions().end()); + break; + case DataLayout::kBatchYXDepth: + output_layout.push_back(dnums.output_batch_dimension()); + output_layout.insert(output_layout.end(), + dnums.output_spatial_dimensions().begin(), + dnums.output_spatial_dimensions().end()); + output_layout.push_back(dnums.output_feature_dimension()); + break; + default: + return tensorflow::errors::Internal("Invalid output layout: ", + DataLayoutString(output)); + } + + return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout), + LayoutUtil::MakeLayoutFromMajorToMinor(filter_layout), + LayoutUtil::MakeLayoutFromMajorToMinor(output_layout)); +} + +StatusOr> +XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, + const Layout& input, const Layout& filter, + const Layout& output) { + Layout nchw_input, nchw_filter, nchw_output; + std::tie(nchw_input, nchw_filter, nchw_output) = + StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX, + 
FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX) + .ConsumeValueOrDie(); + + Layout nhwc_input, nhwc_filter, nhwc_output; + std::tie(nhwc_input, nhwc_filter, nhwc_output) = + StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchYXDepth, + FilterLayout::kOutputYXInput, + DataLayout::kBatchYXDepth) + .ConsumeValueOrDie(); + + DataLayout input_layout; + if (LayoutUtil::Equal(input, nchw_input)) { + input_layout = DataLayout::kBatchDepthYX; + } else if (LayoutUtil::Equal(input, nhwc_input)) { + input_layout = DataLayout::kBatchYXDepth; + } else { + return tensorflow::errors::Internal("Invalid input layout: ", + input.ShortDebugString()); + } + + FilterLayout filter_layout; + if (LayoutUtil::Equal(filter, nchw_filter)) { + filter_layout = FilterLayout::kOutputInputYX; + } else if (LayoutUtil::Equal(filter, nhwc_filter)) { + filter_layout = FilterLayout::kOutputYXInput; + } else { + return tensorflow::errors::Internal("Invalid filter layout: ", + filter.ShortDebugString()); + } + + DataLayout output_layout; + if (LayoutUtil::Equal(output, nchw_output)) { + output_layout = DataLayout::kBatchDepthYX; + } else if (LayoutUtil::Equal(output, nhwc_output)) { + output_layout = DataLayout::kBatchYXDepth; + } else { + return tensorflow::errors::Internal("Invalid output layout: ", + output.ShortDebugString()); + } + + return std::make_tuple(input_layout, filter_layout, output_layout); +} +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h new file mode 100644 index 0000000000000000000000000000000000000000..8218f4fd11d3978d0ecc53fc15e287aea4b69ec3 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ + +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +// Helper functions for interacting with StreamExecutor. + +namespace xla { +namespace gpu { + +// Returns (input, filter, output) XLA Layout protos given the StreamExecutor +// layouts. +StatusOr> +StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, + stream_executor::dnn::DataLayout input, + stream_executor::dnn::FilterLayout filter, + stream_executor::dnn::DataLayout output); + +// Returns (input, filter, output) StreamExecutor layouts given the XLA layouts. +StatusOr> +XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, + const Layout& input, const Layout& filter, + const Layout& output); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index a0c785ed913109e987d058124c8ef49019c98500..931c0bffab850362dbd2df975657dd47d9cbd3ae 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -70,11 +70,14 @@ class Thunk { Kind kind() const { return kind_; } const HloInstruction* hlo_instruction() const { return hlo_instruction_; } - // Prepares for executing the thunk. 
This method is called only once over - // Thunk's lifetime. For example, KernelThunk::Initialize loads the PTX of a - // kernel, which is the same in every execution. - virtual tensorflow::Status Initialize(const GpuExecutable& executable) { - return tensorflow::Status::OK(); + // Prepares the thunk for execution on the given StreamExecutor. + // + // This may be called multiple times. Its main purpose is to give us a chance + // to do initialization outside of ExecuteOnStream() so that the + // time spent initializing doesn't count towards our execution profile. + virtual Status Initialize(const GpuExecutable& /*executable*/, + se::StreamExecutor* /*executor*/) { + return Status::OK(); } // Users of Thunk should call ShouldHaltAllActivityBeforeRunning(stream) @@ -89,21 +92,13 @@ class Thunk { return false; } - // Indicates whether thunks scheduled after this one should wait for this one - // to complete before running. For example, a convolution thunk creates a - // scratch allocator, then kicks off a convolution in cudnn via the stream - // executor. When the stream executor call returns, the scratch allocator goes - // out of scope, and the scratch memory is deallocated. In this case, the - // convolution thunk needs to return true so that future thunks wait for the - // convolution thunk to avoid reusing the deallocated memory until the - // convolution thunk is done with it. - virtual bool ShouldBlockFutureThunks() { return false; } - // Execute the kernel for the thunk on the given stream. This method must be // called after Initialize and can be called multiple times over Thunk's // lifetime. Stream argument must be non-null. - virtual tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) = 0; + // + // Precondition: Initialize(stream->parent()) has been called. 
+ virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) = 0; private: Kind kind_; diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc index ecb54857ccc40ead21e5a18d79a37b545680021d..97cb04c38fbf18e516857f5269c984696ca204c3 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc @@ -20,8 +20,8 @@ limitations under the License. namespace xla { namespace gpu { -tensorflow::Status TupleThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { std::vector tuple_element_buffer_addresses; for (BufferAllocation::Slice tuple_element_buffer : tuple_element_buffers_) { tuple_element_buffer_addresses.push_back( @@ -40,7 +40,7 @@ tensorflow::Status TupleThunk::ExecuteOnStream( tuple_element_buffer_addresses.data(), dest_buffer_address.opaque(), sizeof(void*) * tuple_element_buffer_addresses.size()); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h index 8b459c29a136a6e7853e68a1bead7d12c0d08ad0..951f809b51937c97a6e7de0345ec58a8b66a4242 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h @@ -45,8 +45,8 @@ class TupleThunk : public Thunk { TupleThunk(const TupleThunk&) = delete; TupleThunk& operator=(const TupleThunk&) = delete; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const std::vector tuple_element_buffers_; diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc 
b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index a9f3d619a3ffd6d849572355e2902375e43508fa..30b9640c4c75dae61e9a90da5fb10e9d4a90cd26 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -34,9 +34,11 @@ WhileThunk::WhileThunk( body_thunk_sequence_( MakeUnique(std::move(*body_thunk_sequence), hlo)) {} -Status WhileThunk::Initialize(const GpuExecutable& executable) { - TF_RETURN_IF_ERROR(condition_thunk_sequence_->Initialize(executable)); - TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable)); +Status WhileThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { + TF_RETURN_IF_ERROR( + condition_thunk_sequence_->Initialize(executable, executor)); + TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h index e589ca78a7ea00e7592d6e09ead9c270f902702f..22176685a92df9c95b10f755b209309843c0fa3a 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -45,7 +45,8 @@ class WhileThunk : public Thunk { WhileThunk(const WhileThunk&) = delete; WhileThunk& operator=(const WhileThunk&) = delete; - Status Initialize(const GpuExecutable& executable) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, se::Stream* stream) override; diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index e6caec8625f0d622dbb92bcc20802d254fe23f94..7749201cbceece216a2db2569936949eb7de5125 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -144,7 +144,7 @@ class ExprTree { 
TF_RETURN_IF_ERROR(pair.second->Match(instruction->operand(pair.first), tagged_instructions)); } - return tensorflow::Status::OK(); + return Status::OK(); } private: @@ -169,7 +169,7 @@ class MatcherBase { // Attempts to match each ExprTree in 'expr_trees_'. // Returns OK on the first successful match, error status otherwise. - virtual tensorflow::Status Run() { + virtual Status Run() { Status status; for (const ExprTree& expr_tree : expr_trees_) { status = MatchExprTree(expr_tree); @@ -201,7 +201,7 @@ class MatcherBase { } else if (type == S64) { *const_value = literal.GetFirstElement(); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr GetTaggedInstruction( @@ -315,7 +315,7 @@ class WhileConditionComputationMatcher : public MatcherBase { gte_fusion_param0->name().c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } const HloComputation* computation_; @@ -379,7 +379,7 @@ class WhileInitOperandMatcher : public MatcherBase { GetTaggedInstruction("loop_start", tagged_instructions)); TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_start_)); - return tensorflow::Status::OK(); + return Status::OK(); } const HloInstruction* while_hlo_; @@ -457,8 +457,8 @@ class WhileBodyComputationMatcher : public MatcherBase { return InvalidArgument("Unexpected tuple index instruction : %s", inst->name().c_str()); } else if (tag == "loop_increment") { - // Parse the constant which represents the loop induction variable - // increment value. + // ParseHloString the constant which represents the loop induction + // variable increment value. 
TF_RETURN_IF_ERROR(ParseConstInteger(inst, &loop_increment_)); } else if (tag == "param0" && inst != computation_->parameter_instruction(0)) { @@ -477,7 +477,7 @@ class WhileBodyComputationMatcher : public MatcherBase { } } } - return tensorflow::Status::OK(); + return Status::OK(); } const HloComputation* computation_; diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 3dd4c4a0794e5c41b877078c4e69c6c9584ce6c0..06a5e0351b63270b61b998ca2211f480f256f759 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -32,7 +31,7 @@ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, const SequentialHloOrdering::HloModuleSequence& module_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, const Options& options) { + const BufferValue::SizeFunction& size_fn, const Options& options) { HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence); const HloComputation* entry_computation = module.entry_computation(); const std::vector& instruction_sequence = @@ -47,7 +46,7 @@ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, const Options& options) { + const BufferValue::SizeFunction& size_fn, const Options& options) { HeapSimulator heap(std::move(algorithm), size_fn, options, /*module_sequence=*/nullptr); TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, @@ -73,11 +72,11 @@ Status HeapSimulator::RunComputation( // 'used_buffers' is the 
reverse map - it tracks which buffers were used by an // instruction, so that we can remove the instructions from a buffer's live // set after they are visited. - FlatMap> live_buffers; - FlatMap> used_buffers; + FlatMap> live_buffers; + FlatMap> used_buffers; auto add_user_to_buffer = [this, &live_buffers, &used_buffers]( const HloInstruction* user, - const LogicalBuffer* buffer) { + const BufferValue* buffer) { if (!IgnoreBuffer(buffer)) { VLOG(4) << " Adding user " << user->name() << " to buffer " << buffer->ToString(); @@ -96,7 +95,7 @@ Status HeapSimulator::RunComputation( const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet(); for (const HloInstruction* user : instruction->users()) { if (user->opcode() != HloOpcode::kGetTupleElement) { - for (const LogicalBuffer* buffer : buffer_set) { + for (const BufferValue* buffer : buffer_set) { add_user_to_buffer(user, buffer); } } else { @@ -104,12 +103,12 @@ Status HeapSimulator::RunComputation( // alive. It only needs the buffers that relate to the element its // extracting, and the tuple it's extracting from, but not the buffers // for the other elements. 
- for (const LogicalBuffer* buffer : points_to.element({})) { + for (const BufferValue* buffer : points_to.element({})) { add_user_to_buffer(user, buffer); } const PointsToSet& gte_points_to = points_to_analysis.GetPointsToSet(user); - for (const LogicalBuffer* buffer : gte_points_to.CreateFlattenedSet()) { + for (const BufferValue* buffer : gte_points_to.CreateFlattenedSet()) { add_user_to_buffer(user, buffer); } } @@ -117,24 +116,25 @@ Status HeapSimulator::RunComputation( } const HloInstruction* root = computation.root_instruction(); - auto output_source_buffers = - points_to_analysis.GetPointsToSet(root).CreateFlattenedSet(); + BufferValueCompactPointerSet output_source_buffers = + ToBufferValueCompactPointerSet( + points_to_analysis.GetPointsToSet(root).CreateFlattenedSet()); - std::vector dead_buffers_to_free; - std::vector operand_buffers_to_free; + std::vector dead_buffers_to_free; + std::vector operand_buffers_to_free; for (const HloInstruction* instruction : instruction_sequence) { const TuplePointsToAnalysis::BufferDefinitionVector& buffers_defined_by_instruction = points_to_analysis.GetBuffersDefinedByInstruction(instruction); VLOG(3) << "Instruction: " << instruction->ToString(); - for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { + for (const BufferValue* buffer : buffers_defined_by_instruction) { VLOG(4) << " Defines: " << buffer->ToString() << (IgnoreBuffer(buffer) ? " (Ignored)" : ""); } dead_buffers_to_free.clear(); - for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { + for (const BufferValue* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; } @@ -161,7 +161,7 @@ Status HeapSimulator::RunComputation( // have no instructions left to visit are moved from live_buffers to // operand_buffers_to_free. 
operand_buffers_to_free.clear(); - for (const LogicalBuffer* operand_buffer : used_buffers[instruction]) { + for (const BufferValue* operand_buffer : used_buffers[instruction]) { if (IgnoreBuffer(operand_buffer)) { continue; } @@ -177,7 +177,7 @@ Status HeapSimulator::RunComputation( } // Sort to get a deterministic iteration order. std::sort(operand_buffers_to_free.begin(), operand_buffers_to_free.end(), - [](const LogicalBuffer* x, const LogicalBuffer* y) { + [](const BufferValue* x, const BufferValue* y) { return x->id() < y->id(); }); @@ -188,7 +188,7 @@ Status HeapSimulator::RunComputation( // // INVARIANT: Either Alloc or ShareBuffer will be called for each buffer // that we should assign. - for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { + for (const BufferValue* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; } @@ -199,12 +199,12 @@ Status HeapSimulator::RunComputation( // we must be the last user of the buffer. bool shared = false; if (options_.may_reuse_operand_buffers) { - for (const LogicalBuffer* operand_buffer : operand_buffers_to_free) { + for (const BufferValue* operand_buffer : operand_buffers_to_free) { if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) && buffer->instruction()->opcode() != HloOpcode::kCopy && - CanShareOperandBufferWithUser( + points_to_analysis.CanShareOperandBufferWithUser( operand_buffer->instruction(), operand_buffer->index(), - buffer->instruction(), buffer->index(), points_to_analysis)) { + buffer->instruction(), buffer->index())) { VLOG(3) << " Sharing: " << buffer->ToString() << " with " << operand_buffer->ToString(); ShareBuffer(buffer, operand_buffer, instruction); @@ -248,11 +248,11 @@ Status HeapSimulator::RunComputation( // Free buffers that are no longer live. This is the earliest point that we // can de-allocate; right after the last use of the buffer. 
- for (const LogicalBuffer* buffer : dead_buffers_to_free) { + for (const BufferValue* buffer : dead_buffers_to_free) { VLOG(3) << " Freeing dead: " << buffer->ToString(); Free(buffer, instruction); } - for (const LogicalBuffer* buffer : operand_buffers_to_free) { + for (const BufferValue* buffer : operand_buffers_to_free) { VLOG(3) << " Freeing operand: " << buffer->ToString(); Free(buffer, instruction); } @@ -261,10 +261,10 @@ Status HeapSimulator::RunComputation( // Any remaining live buffers must be entry parameters or output source // buffers, which had a nullptr sentry added. Free them now, in a // deterministic order. - std::vector to_free; + std::vector to_free; to_free.reserve(live_buffers.size()); for (const auto& buffer_pending : live_buffers) { - const LogicalBuffer* buffer = buffer_pending.first; + const BufferValue* buffer = buffer_pending.first; const FlatSet& pending = buffer_pending.second; CHECK_EQ(pending.size(), 1) << *buffer; CHECK(*pending.begin() == nullptr) << *buffer; @@ -272,10 +272,10 @@ Status HeapSimulator::RunComputation( } std::sort(to_free.begin(), to_free.end(), - [](const LogicalBuffer* x, const LogicalBuffer* y) { + [](const BufferValue* x, const BufferValue* y) { return x->id() < y->id(); }); - for (const LogicalBuffer* buffer : to_free) { + for (const BufferValue* buffer : to_free) { VLOG(3) << "Freeing pending: " << buffer->ToString(); Free(buffer, root); } @@ -285,7 +285,7 @@ Status HeapSimulator::RunComputation( HeapSimulator::HeapSimulator( std::unique_ptr algorithm, - const LogicalBuffer::SizeFunction& size_fn, const Options& options, + const BufferValue::SizeFunction& size_fn, const Options& options, const SequentialHloOrdering::HloModuleSequence* module_sequence) : no_fragmentation_stats_(MakeUnique()), algorithm_(std::move(algorithm)), @@ -297,7 +297,7 @@ HeapSimulator::HeapSimulator( HeapSimulator::~HeapSimulator() {} -bool HeapSimulator::IgnoreBuffer(const LogicalBuffer* buffer) const { +bool 
HeapSimulator::IgnoreBuffer(const BufferValue* buffer) const { // Buffers for constants are ignored unless the alloc_constants option is // set. Also ignore buffers that we're not meant to assign. // @@ -311,7 +311,7 @@ bool HeapSimulator::IgnoreBuffer(const LogicalBuffer* buffer) const { } // Alloc always calls the underlying heap algorithm. -void HeapSimulator::Alloc(const LogicalBuffer* buffer, +void HeapSimulator::Alloc(const BufferValue* buffer, const HloInstruction* instruction) { CHECK(allocated_buffers_.count(buffer) == 0) << "Alloc called on allocated buffer: " << *buffer; @@ -331,7 +331,7 @@ void HeapSimulator::Alloc(const LogicalBuffer* buffer, // buffers whose group liveness has expired. Shared group liveness is tracked // by maintaining a refcount; the Free call on the last buffer in the group // causes Free to be called on the underlying algorithm. -void HeapSimulator::Free(const LogicalBuffer* buffer, +void HeapSimulator::Free(const BufferValue* buffer, const HloInstruction* instruction) { auto shared_it = shared_buffers_.find(buffer); if (shared_it != shared_buffers_.end()) { @@ -362,8 +362,8 @@ void HeapSimulator::Free(const LogicalBuffer* buffer, // The 'buffer' must be a non-allocated, non-freed buffer, just like in calls to // Alloc. The 'shared' buffer must be a previously allocated or shared buffer. // Both 'buffer' and 'shared' will be associated with the same SharedGroup. 
-void HeapSimulator::ShareBuffer(const LogicalBuffer* buffer, - const LogicalBuffer* shared, +void HeapSimulator::ShareBuffer(const BufferValue* buffer, + const BufferValue* shared, const HloInstruction* instruction) { CHECK_LE(size_fn_(*buffer), size_fn_(*shared)) << "ShareBuffer oversized buffer" << *buffer << " shared: " << *shared; @@ -374,7 +374,7 @@ void HeapSimulator::ShareBuffer(const LogicalBuffer* buffer, CHECK(freed_buffers_.count(shared) == 0) << "ShareBuffer called on freed shared buffer: " << *shared; - const LogicalBuffer* canonical = nullptr; + const BufferValue* canonical = nullptr; auto shared_it = shared_buffers_.find(shared); if (shared_it != shared_buffers_.end()) { // The 'shared' buffer already has a group; it might be the canonical, but @@ -408,7 +408,7 @@ HeapSimulator::Result HeapSimulator::Finish() { // collecting statistics, e.g. NoFragmentationStatsHeap. if (!result.chunk_map.empty()) { for (const auto& share_pair : shared_buffers_) { - const LogicalBuffer* buffer = share_pair.first; + const BufferValue* buffer = share_pair.first; std::shared_ptr group = share_pair.second; if (buffer != group->canonical) { // The canonical must already exist in the chunk_map, since we called @@ -437,9 +437,9 @@ HeapSimulator::Result HeapSimulator::Finish() { } void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, - const LogicalBuffer* buffer, + const BufferValue* buffer, const HloInstruction* instruction, - const LogicalBuffer* share_with_canonical) { + const BufferValue* share_with_canonical) { HeapSimulatorTrace::Event* event = debug_trace_.add_events(); event->set_kind(kind); event->set_buffer_id(buffer->id()); @@ -453,14 +453,14 @@ void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, } } -void NoFragmentationStatsHeap::Alloc(const LogicalBuffer* buffer, int64 size) { +void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) { current_heap_size_ += size; if (current_heap_size_ > 
max_heap_size_) { max_heap_size_ = current_heap_size_; } } -void NoFragmentationStatsHeap::Free(const LogicalBuffer* buffer, int64 size) { +void NoFragmentationStatsHeap::Free(const BufferValue* buffer, int64 size) { current_heap_size_ -= size; } @@ -472,12 +472,12 @@ HeapSimulator::Result NoFragmentationStatsHeap::Finish() { return result; } -void DecreasingSizeRunsHeap::Alloc(const LogicalBuffer* buffer, int64 size) { +void DecreasingSizeRunsHeap::Alloc(const BufferValue* buffer, int64 size) { SetMode(kAlloc); run_.emplace_back(Op{buffer, size}); } -void DecreasingSizeRunsHeap::Free(const LogicalBuffer* buffer, int64 size) { +void DecreasingSizeRunsHeap::Free(const BufferValue* buffer, int64 size) { CHECK(mode_ != kInit) << "Free called on empty heap: " << *buffer; SetMode(kFree); run_.emplace_back(Op{buffer, size}); @@ -518,7 +518,7 @@ void DecreasingSizeRunsHeap::CallAndDrainRun() { run_.clear(); } -void LazyBestFitHeap::Alloc(const LogicalBuffer* buffer, int64 size) { +void LazyBestFitHeap::Alloc(const BufferValue* buffer, int64 size) { // Degenerate case: 0-sized buffers are always allocated at offset 0. 
if (size == 0) { result_.chunk_map.emplace(buffer, Chunk{0, 0}); @@ -586,7 +586,7 @@ void LazyBestFitHeap::Alloc(const LogicalBuffer* buffer, int64 size) { result_.chunk_map.emplace(buffer, Chunk{kLazyAllocOffset, size}); } -void LazyBestFitHeap::Free(const LogicalBuffer* buffer, int64 size) { +void LazyBestFitHeap::Free(const BufferValue* buffer, int64 size) { auto alloc_it = result_.chunk_map.find(buffer); CHECK(alloc_it != result_.chunk_map.end()) << "Free called on non-allocated buffer: " << *buffer; diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 636f19dd39f09721bd82fc4b44785f196f281ad7..8b2b43a37a5c41d334e5338c6a6fad160f03a51e 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -21,11 +21,12 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -43,7 +44,7 @@ class HeapAlgorithm; // don't need to return the assignment of buffer offsets until the very end. class HeapSimulator { public: - // Chunk represents a contiguous piece of memory. Each LogicalBuffer will be + // Chunk represents a contiguous piece of memory. Each BufferValue will be // associated with a chunk in the assignment result. struct Chunk { int64 offset; @@ -55,7 +56,7 @@ class HeapSimulator { // Result represents the result of the heap simulation. 
struct Result { // The assignment of buffers to chunks. - tensorflow::gtl::FlatMap chunk_map; + tensorflow::gtl::FlatMap chunk_map; // The total size in bytes of the heap, containing all assigned chunks. int64 heap_size = 0; @@ -81,7 +82,7 @@ class HeapSimulator { bool alloc_constants; // If 'buffers_to_assign' is provided, only those buffers are assigned // offsets, otherwise all buffers defined by the instructions are assigned. - const tensorflow::gtl::FlatSet* buffers_to_assign; + const BufferValueFlatSet* buffers_to_assign; }; // Run the heap simulation with the given algorithm, assuming the given @@ -97,7 +98,7 @@ class HeapSimulator { std::unique_ptr algorithm, const HloModule& module, const SequentialHloOrdering::HloModuleSequence& module_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, + const BufferValue::SizeFunction& size_fn, const Options& options = Options()); // Same as above, but runs on a single computation. The 'instruction_sequence' @@ -109,7 +110,7 @@ class HeapSimulator { const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, + const BufferValue::SizeFunction& size_fn, const Options& options = Options()); private: @@ -118,7 +119,7 @@ class HeapSimulator { // be run recursively. I.e. the simulation is run over the whole module. 
HeapSimulator( std::unique_ptr algorithm, - const LogicalBuffer::SizeFunction& size_fn, const Options& options, + const BufferValue::SizeFunction& size_fn, const Options& options, const SequentialHloOrdering::HloModuleSequence* module_sequence); ~HeapSimulator(); @@ -127,21 +128,21 @@ class HeapSimulator { const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis); - bool IgnoreBuffer(const LogicalBuffer* buffer) const; - void Alloc(const LogicalBuffer* buffer, const HloInstruction* instruction); - void Free(const LogicalBuffer* buffer, const HloInstruction* instruction); - void ShareBuffer(const LogicalBuffer* buffer, const LogicalBuffer* shared, + bool IgnoreBuffer(const BufferValue* buffer) const; + void Alloc(const BufferValue* buffer, const HloInstruction* instruction); + void Free(const BufferValue* buffer, const HloInstruction* instruction); + void ShareBuffer(const BufferValue* buffer, const BufferValue* shared, const HloInstruction* instruction); Result Finish(); void FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, - const LogicalBuffer* buffer, + const BufferValue* buffer, const HloInstruction* instruction, - const LogicalBuffer* shared_with_canonical); + const BufferValue* shared_with_canonical); const std::unique_ptr no_fragmentation_stats_; const std::unique_ptr algorithm_; - const LogicalBuffer::SizeFunction size_fn_; + const BufferValue::SizeFunction size_fn_; const Options options_; const SequentialHloOrdering::HloModuleSequence* module_sequence_; @@ -160,15 +161,15 @@ class HeapSimulator { // The shared_buffers_ map associates each shared buffer (including the // canonical) to its SharedGroup control block. struct SharedGroup { - const LogicalBuffer* canonical = nullptr; + const BufferValue* canonical = nullptr; int64 refcount = 0; }; - tensorflow::gtl::FlatMap> + tensorflow::gtl::FlatMap> shared_buffers_; // Hold some sets for error-checking the sequence of Alloc and Free calls. 
- tensorflow::gtl::FlatSet allocated_buffers_; - tensorflow::gtl::FlatSet freed_buffers_; + tensorflow::gtl::FlatSet allocated_buffers_; + tensorflow::gtl::FlatSet freed_buffers_; // Debugging information filled in while the heap simulator runs. HeapSimulatorTrace debug_trace_; @@ -186,10 +187,10 @@ class HeapAlgorithm { virtual ~HeapAlgorithm() = default; // Alloc allocates a buffer of 'size' bytes. - virtual void Alloc(const LogicalBuffer* buffer, int64 size) = 0; + virtual void Alloc(const BufferValue* buffer, int64 size) = 0; // Free de-allocates a previously allocated buffer. - virtual void Free(const LogicalBuffer* buffer, int64 size) = 0; + virtual void Free(const BufferValue* buffer, int64 size) = 0; // Finish collects the buffer offset assignment results. Free may only be // called once, after the Alloc and Free calls. @@ -205,8 +206,8 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { NoFragmentationStatsHeap() = default; ~NoFragmentationStatsHeap() override = default; - void Alloc(const LogicalBuffer* buffer, int64 size) override; - void Free(const LogicalBuffer* buffer, int64 size) override; + void Alloc(const BufferValue* buffer, int64 size) override; + void Free(const BufferValue* buffer, int64 size) override; Result Finish() override; private: @@ -223,14 +224,14 @@ class DecreasingSizeRunsHeap : public HeapAlgorithm { : algorithm_(std::move(algorithm)) {} ~DecreasingSizeRunsHeap() override {} - void Alloc(const LogicalBuffer* buffer, int64 size) override; - void Free(const LogicalBuffer* buffer, int64 size) override; + void Alloc(const BufferValue* buffer, int64 size) override; + void Free(const BufferValue* buffer, int64 size) override; Result Finish() override; private: // A single Alloc or Free operation that we've buffered in run_. 
struct Op { - const LogicalBuffer* buffer; + const BufferValue* buffer; int64 size; }; @@ -266,8 +267,8 @@ class LazyBestFitHeap : public HeapAlgorithm { LazyBestFitHeap(int64 alignment) : alignment_(alignment) {} ~LazyBestFitHeap() override {} - void Alloc(const LogicalBuffer* buffer, int64 size) override; - void Free(const LogicalBuffer* buffer, int64 size) override; + void Alloc(const BufferValue* buffer, int64 size) override; + void Free(const BufferValue* buffer, int64 size) override; Result Finish() override; private: diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index fd56a603bb6f849b1c1f1578fe7395d9b372e2d5..6271652412c2979ff926702f12722102344b0dfb 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -39,7 +39,7 @@ const char kFree[] = "Free"; const char kFinish[] = "Finish"; // CallSequence records a sequence of Alloc/Free/Finish calls. -using CallSequence = std::vector>; +using CallSequence = std::vector>; // HeapCallRecorder is a dummy heap algorithm that simply records its calls. 
class HeapCallRecorder : public HeapAlgorithm { @@ -47,7 +47,7 @@ class HeapCallRecorder : public HeapAlgorithm { explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {} ~HeapCallRecorder() override {} - void Alloc(const LogicalBuffer* buffer, int64 size) override { + void Alloc(const BufferValue* buffer, int64 size) override { calls_->emplace_back(kAlloc, buffer); // Instead of assigning a real offset, we set the cardinality of the Alloc // call. This isn't a valid assignment, but allows us to easily test for @@ -55,7 +55,7 @@ class HeapCallRecorder : public HeapAlgorithm { const int64 offset = result_.chunk_map.size(); result_.chunk_map.emplace(buffer, Chunk{offset, size}); } - void Free(const LogicalBuffer* buffer, int64 size) override { + void Free(const BufferValue* buffer, int64 size) override { calls_->emplace_back(kFree, buffer); } Result Finish() override { @@ -118,7 +118,7 @@ class HeapSimulatorTracker { // Hack the size_fn so that it returns a decreasing value as we step through // the sequence. This lets us ensure the Alloc calls are in the sequence - // order. The Free calls are sorted by LogicalBuffer.id, which is at least + // order. The Free calls are sorted by BufferValue.id, which is at least // deterministic. auto size_fn = [&reverse_position](const BufferValue& buffer) { return reverse_position[buffer.instruction()]; @@ -133,8 +133,8 @@ class HeapSimulatorTracker { HloModule* module() { return module_.get(); } // Returns the buffer defined at the given instruction and index. 
- const LogicalBuffer* BufferAt(const HloInstruction* instruction, - const ShapeIndex& index) const { + const BufferValue* BufferAt(const HloInstruction* instruction, + const ShapeIndex& index) const { return points_to_analysis_->GetBufferDefinedAt(instruction, index) .ConsumeValueOrDie(); } @@ -150,8 +150,8 @@ class HeapSimulatorTracker { const ShapeIndex& index_a, const HloInstruction* instruction_b, const ShapeIndex& index_b) { - const LogicalBuffer* a = BufferAt(instruction_a, index_a); - const LogicalBuffer* b = BufferAt(instruction_b, index_b); + const BufferValue* a = BufferAt(instruction_a, index_a); + const BufferValue* b = BufferAt(instruction_b, index_b); EXPECT_EQ(result_.chunk_map[a].offset, result_.chunk_map[b].offset) << *a << ", " << *b; } @@ -525,7 +525,7 @@ TEST_F(HeapSimulatorTest, WholeModule) { // Now the final cond less-than buffer is allocated. {kAlloc, tracker.BufferAt(cond_lt, {})}, - // The order of the remaining Free calls is based on the LogicalBuffer.id, + // The order of the remaining Free calls is based on the BufferValue.id, // which is deterministic, but not obvious. 
{kFree, tracker.BufferAt(param, {})}, {kFree, tracker.BufferAt(param, {0})}, @@ -547,40 +547,40 @@ TEST_F(HeapSimulatorTest, WholeModule) { class HeapAlgorithmTestBase : public ::testing::Test { protected: HeapAlgorithmTestBase() : builder_("heap_simulator_test") { - buffer_a_ = DummyLogicalBuffer(); - buffer_b_ = DummyLogicalBuffer(); - buffer_c_ = DummyLogicalBuffer(); - buffer_d_ = DummyLogicalBuffer(); - buffer_e_ = DummyLogicalBuffer(); - buffer_f_ = DummyLogicalBuffer(); - buffer_g_ = DummyLogicalBuffer(); - buffer_h_ = DummyLogicalBuffer(); - buffer_i_ = DummyLogicalBuffer(); + buffer_a_ = DummyBufferValue(); + buffer_b_ = DummyBufferValue(); + buffer_c_ = DummyBufferValue(); + buffer_d_ = DummyBufferValue(); + buffer_e_ = DummyBufferValue(); + buffer_f_ = DummyBufferValue(); + buffer_g_ = DummyBufferValue(); + buffer_h_ = DummyBufferValue(); + buffer_i_ = DummyBufferValue(); } ~HeapAlgorithmTestBase() override {} - const LogicalBuffer* buffer_a_; - const LogicalBuffer* buffer_b_; - const LogicalBuffer* buffer_c_; - const LogicalBuffer* buffer_d_; - const LogicalBuffer* buffer_e_; - const LogicalBuffer* buffer_f_; - const LogicalBuffer* buffer_g_; - const LogicalBuffer* buffer_h_; - const LogicalBuffer* buffer_i_; + const BufferValue* buffer_a_; + const BufferValue* buffer_b_; + const BufferValue* buffer_c_; + const BufferValue* buffer_d_; + const BufferValue* buffer_e_; + const BufferValue* buffer_f_; + const BufferValue* buffer_g_; + const BufferValue* buffer_h_; + const BufferValue* buffer_i_; private: - // Create a dummy LogicalBuffer to pass to the heap algorithm. - const LogicalBuffer* DummyLogicalBuffer() { - const LogicalBuffer::Id id = buffers_.size(); + // Create a dummy BufferValue to pass to the heap algorithm. 
+ const BufferValue* DummyBufferValue() { + const BufferValue::Id id = buffers_.size(); auto const0 = builder_.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - buffers_.emplace_back(MakeUnique(const0, ShapeIndex{}, id)); + buffers_.emplace_back(MakeUnique(id, const0, ShapeIndex{})); return buffers_.back().get(); } HloComputation::Builder builder_; - std::vector> buffers_; + std::vector> buffers_; }; class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {}; diff --git a/tensorflow/compiler/xla/service/hlo_casting_utils.h b/tensorflow/compiler/xla/service/hlo_casting_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..7f73bba036534a62a70a80431236cffa766c9b38 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_casting_utils.h @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Casting utility functions for HLO instructions.
+ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ + +#include +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +class HloInstruction; + +template +using EnableIfDerivedFromHlo = + typename std::enable_if::value>::type; + +// TODO(b/93238915): Switch implementation from C++'s dynamic_cast to LLVM-like +// RTTI if it turns out to be a performance issue. +// Casts an HloInstruction pointer to one of its subclasses, dies if argument is +// nullptr or runtime information does not match. +// +// Similar to LLVM's cast. +template * = nullptr> +const T* Cast(const HloInstruction* instruction) { + CHECK(instruction != nullptr); + const T* casted = dynamic_cast(instruction); + CHECK(casted != nullptr); + return casted; +} + +// Non-const overload of Cast. +template * = nullptr> +T* Cast(HloInstruction* instruction) { + return const_cast( + Cast(const_cast(instruction))); +} + +// Works just like the Cast, except that it allows for a null pointer as an +// argument which it then propagates. +// +// Similar to LLVM's cast_or_null. +template * = nullptr> +const T* CastOrNull(const HloInstruction* instruction) { + return instruction != nullptr ? Cast(instruction) : nullptr; +} + +// Non-const overload of CastOrNull. +template * = nullptr> +T* CastOrNull(HloInstruction* instruction) { + return const_cast( + CastOrNull(const_cast(instruction))); +} + +// Casts an HloInstruction pointer to one of its subclasses, dies if argument is +// nullptr, returns nullptr if runtime information does not match. +// +// Similar to LLVM's dyn_cast. +template * = nullptr> +const T* DynCast(const HloInstruction* instruction) { + CHECK(instruction != nullptr); + return dynamic_cast(instruction); +} + +// Non-const overload of DynCast. 
+template * = nullptr> +T* DynCast(HloInstruction* instruction) { + return const_cast( + DynCast(const_cast(instruction))); +} + +// Works just like the DynCast, except that it allows for a null pointer as an +// argument which it then propagates. +// +// Similar to LLVM's dyn_cast_or_null. +template * = nullptr> +const T* DynCastOrNull(const HloInstruction* instruction) { + return instruction != nullptr ? DynCast(instruction) : nullptr; +} + +// Non-const overload of DynCastOrNull. +template * = nullptr> +T* DynCastOrNull(HloInstruction* instruction) { + return const_cast( + DynCastOrNull(const_cast(instruction))); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc b/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a3364275409122254bf99b40a7d2fcbb2d7564cc --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc @@ -0,0 +1,113 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class DummyInstruction : public HloInstruction { + public: + DummyInstruction() + : HloInstruction(HloOpcode::kConstant, ShapeUtil::MakeShape(F32, {})) {} +}; + +class AnotherDummyInstruction : public HloInstruction { + public: + AnotherDummyInstruction() + : HloInstruction(HloOpcode::kParameter, ShapeUtil::MakeShape(F32, {})) {} +}; + +TEST(HloCastingUtilsTest, CastSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = + Cast(static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, CastDiesForWrongType) { + AnotherDummyInstruction instruction; + ASSERT_DEATH( + Cast(static_cast(&instruction)), ""); +} + +TEST(HloCastingUtilsTest, CastDiesForNullptr) { + HloInstruction* null = nullptr; + ASSERT_DEATH(Cast(null), ""); +} + +TEST(HloCastingUtilsTest, CastOrNullSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = + Cast(static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, CastOrNullDiesForWrongType) { + AnotherDummyInstruction instruction; + ASSERT_DEATH( + Cast(static_cast(&instruction)), ""); +} + +TEST(HloCastingUtilsTest, CastOrNullReturnsNullptrForNullptr) { + HloInstruction* null = nullptr; + DummyInstruction* casted = CastOrNull(null); + ASSERT_EQ(casted, nullptr); +} + +TEST(HloCastingUtilsTest, DynCastSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = + DynCast(static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, DynCastReturnsNullptrForWrongType) { + AnotherDummyInstruction instruction; + DummyInstruction* casted = + DynCast(static_cast(&instruction)); + ASSERT_EQ(casted, nullptr); +} + 
+TEST(HloCastingUtilsTest, DynCastDiesForNullptr) { + HloInstruction* null = nullptr; + ASSERT_DEATH(DynCast(null), ""); +} + +TEST(HloCastingUtilsTest, DynCastOrNullSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = DynCastOrNull( + static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, DynCastOrNullReturnsNullptrForWrongType) { + AnotherDummyInstruction instruction; + DummyInstruction* casted = DynCastOrNull( + static_cast(&instruction)); + ASSERT_EQ(casted, nullptr); +} + +TEST(HloCastingUtilsTest, DynCastOrNullReturnsNullptrForNullptr) { + HloInstruction* null = nullptr; + DummyInstruction* casted = DynCastOrNull(null); + ASSERT_EQ(casted, nullptr); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_clone_context.h b/tensorflow/compiler/xla/service/hlo_clone_context.h new file mode 100644 index 0000000000000000000000000000000000000000..658643b427a9625fac1166151a89cbd669f817d5 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_clone_context.h @@ -0,0 +1,97 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_ + +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace xla { + +class HloInstruction; +class HloComputation; +class HloModule; + +// Data structure used to track the cloning of HloInstruction and HloComputation +// objects. +class HloCloneContext { + public: + // Creates a new HloCloneContext object to clone HloInstruction and + // HloComputation objects to be added to the module specified as argument. + // The suffix string will be appended to computation names. + explicit HloCloneContext(HloModule* module, const string& suffix = "") + : module_(module), suffix_(suffix) {} + + HloModule* module() const { return module_; } + + const string& suffix() const { return suffix_; } + + void MapInstruction(const HloInstruction* old_instruction, + HloInstruction* new_instruction) { + instructions_[old_instruction] = new_instruction; + } + + void MapComputation(const HloComputation* old_computation, + HloComputation* new_computation) { + computations_[old_computation] = new_computation; + } + + // Finds the new instruction mapped to its old copy, or return nullptr in case + // it is not found. + HloInstruction* FindInstruction(const HloInstruction* old_instruction) const { + return FindOrDefault(instructions_, old_instruction, nullptr); + } + + // Finds the new computation mapped to its old copy, or return nullptr in case + // it is not found. + HloComputation* FindComputation(const HloComputation* old_computation) const { + return FindOrDefault(computations_, old_computation, nullptr); + } + + // Retrieves the new instruction mapped to its old copy, or fail if not found. 
+ HloInstruction* GetInstruction(const HloInstruction* old_instruction) const { + return FindOrDie(instructions_, old_instruction); + } + + // Retrieves the new computation mapped to its old copy, or fail if not found. + HloComputation* GetComputation(const HloComputation* old_computation) const { + return FindOrDie(computations_, old_computation); + } + + const tensorflow::gtl::FlatMap& + cloned_instructions() const { + return instructions_; + } + + const tensorflow::gtl::FlatMap& + cloned_computations() const { + return computations_; + } + + private: + HloModule* module_; + string suffix_; + tensorflow::gtl::FlatMap + instructions_; + tensorflow::gtl::FlatMap + computations_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_ diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 17e43c3cb826aaba584ca5652bcdcb8cb829cb36..b158f44923982642615543cebbd54f00596e2641 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -64,7 +64,7 @@ HloComputation::HloComputation( const string& name, int parameter_count, std::vector>* instructions, HloInstruction* root_instruction, HloInstruction* fusion_instruction) - : name_(name), + : name_(NameUniquer::GetSanitizedName(name)), unique_id_(-1), root_instruction_(root_instruction), fusion_instruction_(fusion_instruction) { @@ -315,6 +315,42 @@ void ComputeComputationPostOrder( } } +std::list ComputeInstructionPostOrder( + HloInstruction* root, tensorflow::gtl::FlatSet* visited) { + std::list post_order; + std::vector> dfs_stack; + dfs_stack.emplace_back(root, false); + while (!dfs_stack.empty()) { + const auto current = dfs_stack.back(); + if (current.second) { + dfs_stack.pop_back(); + if (!visited->insert(current.first).second) { + continue; + } + post_order.push_back(current.first); + } else { + if (visited->count(current.first)) { + 
dfs_stack.pop_back(); + continue; + } + dfs_stack.back().second = true; + + // Add the operands to the stack in reverse order so the first operand is + // processed first. This will produce a more natural ordering and a nicer + // result for things like HLO stringification. + const auto& operands = current.first->operands(); + for (int64 i = operands.size() - 1; i >= 0; --i) { + dfs_stack.emplace_back(operands[i], false); + } + + for (HloInstruction* op : current.first->control_predecessors()) { + dfs_stack.emplace_back(op, false); + } + } + } + return post_order; +} + } // namespace std::list HloComputation::MakeInstructionPostOrder() const { @@ -328,9 +364,9 @@ std::list HloComputation::MakeInstructionPostOrder() const { // users). trace_instructions.push_back(instruction.get()); } else if (instruction->users().empty()) { - post_order.splice(post_order.end(), - InstructionPostOrderer::GetOrder(instruction.get(), - &added_instructions)); + post_order.splice( + post_order.end(), + ComputeInstructionPostOrder(instruction.get(), &added_instructions)); } } post_order.splice(post_order.end(), trace_instructions); @@ -365,25 +401,38 @@ std::list HloComputation::MakeEmbeddedComputationsList() string HloComputation::ToString(const HloPrintOptions& options) const { std::ostringstream s; for (int i = 0; i < options.indent_amount(); i++) { - s << " "; + s << " "; } - if (options.print_percent()) { - s << "%"; + + if (!options.is_in_nested_computation()) { + if (options.print_percent()) { + s << "%"; + } + s << name() << " "; } - s << name(); + if (options.print_program_shape()) { - s << " " << ShapeUtil::HumanString(ComputeProgramShape()); - } - s << " {\n"; - for (const HloInstruction* instruction : MakeInstructionPostOrder()) { - for (int i = 0; i < options.indent_amount(); i++) { - s << " "; + s << ShapeUtil::HumanString(ComputeProgramShape()) << " "; + } + s << "{\n"; + { + // Print the instructions in this computation.
+ HloPrintOptions new_options = options; + new_options.set_indent_amount(options.indent_amount() + 1) + .set_is_in_nested_computation(true); + CanonicalNameMap name_map; + for (const HloInstruction* instruction : MakeInstructionPostOrder()) { + for (int i = 0; i < new_options.indent_amount(); i++) { + s << " "; + } + s << (instruction == root_instruction_ ? "ROOT " : "") + << instruction->ToStringWithCanonicalNameMap(new_options, &name_map) + << "\n"; } - s << " " << (instruction == root_instruction_ ? "ROOT " : "") - << instruction->ToString(options) << "\n"; } + for (int i = 0; i < options.indent_amount(); i++) { - s << " "; + s << " "; } s << "}"; return s.str(); @@ -407,27 +456,37 @@ HloComputationProto HloComputation::ToProto() const { /* static */ StatusOr> HloComputation::CreateFromProto( - HloModule* module, const HloComputationProto& proto, + const HloComputationProto& proto, const tensorflow::gtl::FlatMap& computation_map) { - std::vector> instructions; tensorflow::gtl::FlatMap instruction_map; + tensorflow::gtl::FlatMap to_proto_id; + std::vector> instructions; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { TF_ASSIGN_OR_RETURN( std::unique_ptr instruction, - HloInstruction::CreateFromProto(module, instruction_proto, - instruction_map, computation_map)); + HloInstruction::CreateFromProto(instruction_proto, instruction_map, + computation_map)); if (instruction->opcode() == HloOpcode::kParameter) { parameter_count++; } TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id())); instruction_map[instruction_proto.id()] = instruction.get(); + to_proto_id[instruction.get()] = instruction_proto.id(); instructions.push_back(std::move(instruction)); } TF_RET_CHECK(proto.root_id() != -1); TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id())); HloInstruction* root = instruction_map.at(proto.root_id()); + + // Sort the instructions in the proto id's order. 
+ std::sort(instructions.begin(), instructions.end(), + [&](const std::unique_ptr& a, + const std::unique_ptr& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); + return WrapUnique(new HloComputation(proto.name(), parameter_count, &instructions, root, /*fusion_instruction=*/nullptr)); @@ -729,22 +788,21 @@ Status HloComputation::Accept( } std::unique_ptr HloComputation::Clone( - const string& suffix, HloModule* module, - HloInstruction::CloneMap* clone_map) { + const string& suffix, HloCloneContext* context) { return CloneWithReplacements( /*replacements=*/std::unordered_map>(), - module, clone_map, suffix); + context, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( std::unordered_map> replacements, - HloModule* module, HloInstruction::CloneMap* clone_map, - const string& suffix) { - HloInstruction::CloneMap local_clone_map; - if (clone_map == nullptr) { - clone_map = &local_clone_map; + HloCloneContext* context, const string& suffix) { + std::unique_ptr context_ptr; + if (context == nullptr) { + context_ptr = MakeUnique(parent(), suffix); + context = context_ptr.get(); } // Look up instr in the replacements map, and return either the replacement, @@ -769,18 +827,18 @@ std::unique_ptr HloComputation::CloneWithReplacements( } std::vector> instructions; - std::unique_ptr new_instr = nullptr; + std::unique_ptr new_instr; for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { auto replaced_operand = replace(operand); CHECK_NE(replaced_operand, nullptr) - << "Replacements map specifies to leave out " << operand->ToString() - << ", but it is used by " << instr->ToString() << "."; - new_operands.push_back(FindOrDie(*clone_map, replaced_operand)); + << "replacements map tried to eliminate a used instruction " + << operand->ToString() << ", used by " << instr->ToString(); + new_operands.push_back(context->GetInstruction(replaced_operand)); } - new_instr = instr->CloneWithNewOperands(instr->shape(), 
new_operands, - module, clone_map); + new_instr = + instr->CloneWithNewOperands(instr->shape(), new_operands, context); instructions.push_back(std::move(new_instr)); } Builder builder(name() + "." + suffix); @@ -788,22 +846,23 @@ std::unique_ptr HloComputation::CloneWithReplacements( builder.AddInstruction(std::move(instr)); } auto result = builder.Build( - /*root_instruction=*/FindOrDie(*clone_map, replace(root_instruction()))); + /*root_instruction=*/context->GetInstruction( + replace(root_instruction()))); // Clone control dependencies. for (auto instr : postorder) { - HloInstruction* new_instr = FindOrDie(*clone_map, instr); + HloInstruction* new_instr = context->GetInstruction(instr); for (auto successor : instr->control_successors()) { auto replaced_successor = replace(successor); - CHECK_NE(replaced_successor, nullptr) - << "Replacements map specifies to leave out " << successor->ToString() - << ", but it is control-depended-on by " << instr->ToString() << "."; - - TF_CHECK_OK(new_instr->AddControlDependencyTo( - FindOrDie(*clone_map, replaced_successor))); + // successor may not have been remapped, because it might have been + // removed by the replacements map. + if (replaced_successor != nullptr) { + TF_CHECK_OK(new_instr->AddControlDependencyTo( + context->GetInstruction(replaced_successor))); + } } } - + context->MapComputation(this, result.get()); // We cloned the elements of 'replacements', so they're all going to be // destroyed. 
HloInstructions need to be detached from their operands before // they're destroyed, otherwise they stick around in the operands' users lists @@ -813,7 +872,6 @@ std::unique_ptr HloComputation::CloneWithReplacements( new_instr->DetachFromOperands(); } } - return result; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 98983556256cec01759f924c7d02993cbe18c891..0da4a305f3d5d694a1918fed294337100b0a27fd 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" @@ -49,9 +50,20 @@ class HloModule; // Describes a computation at the HLO level. // -// An HloComputation contains a directed acyclic graph of HLO instructions. The -// computation has a single root instruction which produces the output of the -// computation. +// You can think of an HloComputation like a function. It has some inputs +// (parameters) and returns exactly one value (the value of its root node). If +// you want to return multiple values, you can return a tuple. +// +// The instructions inside of a computation do not have an explicit total order. +// Instead, they have a partial order determined by their data and control +// dependencies. +// +// An HloModule contains one "entry computation" -- this is like main() in a C +// program. Every other computation inside of a module is attached to one or +// more HloInstructions, as a "nested computation". 
For example, the kMap +// instruction has a nested computation and "applies" it to every element of its +// input, elementwise. (That is, the input [x, y, z] is transformed to [f(x), +// f(y), f(z)].) class HloComputation { public: // Builder class for HloComputation. @@ -157,14 +169,12 @@ class HloComputation { // Creates a computation from the given proto. Arguments: // - // module: the module which will contain the computation. The newly created - // computation is *not* added to the module, however. // proto: the proto to convert from. // computation_map: a map from computation id to HloComputation*. This map // must contain all computations which the newly constructed computation // calls. static StatusOr> CreateFromProto( - HloModule* module, const HloComputationProto& proto, + const HloComputationProto& proto, const tensorflow::gtl::FlatMap& computation_map); // Gets the instructions in this computation. @@ -291,17 +301,11 @@ class HloComputation { const std::function& visitor_func) const; // Returns a deep copy of this computation including all instructions. - // - // If the module pointer is not nullptr, then the cloned computations will be - // added to this module in order to support deep cloning. Otherwise the module - // of the computation is used. - // - // If clone_map is not nullptr, then each original instruction that is cloned - // will be inserted and map to its clone. clone_map should not already contain - // any of the instructions to clone. - std::unique_ptr Clone( - const string& suffix = "clone", HloModule* module = nullptr, - HloInstruction::CloneMap* clone_map = nullptr); + // If the clone context is specified, it will be populated with the cloned + // object mappings, and its module() will be used to add new computations + // into. 
+ std::unique_ptr Clone(const string& suffix = "clone", + HloCloneContext* context = nullptr); // Like Clone(), but if an instruction is present in replacement_map, we use // the map's value to replace that instruction in the cloned computation. @@ -311,9 +315,7 @@ class HloComputation { std::unique_ptr CloneWithReplacements( std::unordered_map> replacements, - HloModule* module = nullptr, - HloInstruction::CloneMap* clone_map = nullptr, - const string& suffix = "clone"); + HloCloneContext* context = nullptr, const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. // Parameter instructions cannot be removed without violating invariants of diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 7b7588f4ba9aa622677db6f9d5022cc8cc029e04..25469a54c48f4f5cab478aba929f1cc18de8b81f 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -550,6 +550,108 @@ TEST_F(HloComputationTest, Reachability) { EXPECT_FALSE(reachability->IsReachable(constant2, copy)); } +TEST_F(HloComputationTest, Stringification) { + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); 
+ auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto options = HloPrintOptions().set_print_metadata(false); + EXPECT_EQ(computation->ToString(options), + R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + %x = f32[5,10]{1,0} parameter(0) + %y = f32[20,10]{1,0} parameter(1) + %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} + ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"); +} + +TEST_F(HloComputationTest, StringificationIndent) { + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto options = + HloPrintOptions().set_print_metadata(false).set_indent_amount(2); + EXPECT_EQ(computation->ToString(options), + R"( %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + %x = f32[5,10]{1,0} parameter(0) + %y = f32[20,10]{1,0} parameter(1) + %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} + ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} + })"); +} + 
+TEST_F(HloComputationTest, StringificationCanonical) { + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto options = HloPrintOptions().set_print_metadata(false); + EXPECT_EQ(computation->ToString(options), + R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + %x = f32[5,10]{1,0} parameter(0) + %y = f32[20,10]{1,0} parameter(1) + %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} + ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"); + + options = HloPrintOptions().Canonical(); + EXPECT_EQ(computation->ToString(options), R"(TransposeDot { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 
7b552ee5b1798c4c7e24884a392c5982d7fb17ff..5d05ccfc0b223d8749a2577ba1bf96b1ab3e761b 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -149,7 +149,7 @@ TEST_F(HloConstantFoldingTest, Slice) { const int64 slice_limits[] = {10, 8, 6, 5, 9}; const int64 slice_strides[] = {1, 1, 1, 1, 1}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - LiteralTestUtil::CreateRandomLiteral( + Literal::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -172,7 +172,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { HloComputation::Builder builder(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - LiteralTestUtil::CreateRandomLiteral( + Literal::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); auto literal_clone = literal->Literal::CloneToUnique(); HloInstruction* literal_instruction = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 44e4f75f75b275653e1a07111943843fc6f78b33..92a66681a95afcb4531158c133e36ed562e6841d 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -142,19 +142,25 @@ Status HloCostAnalysis::HandleReducePrecision(const HloInstruction* hlo) { } Status HloCostAnalysis::HandleParameter(const HloInstruction*) { + current_should_compute_bottleneck_time_ = false; current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } Status HloCostAnalysis::HandleConstant(const HloInstruction*) { + current_should_compute_bottleneck_time_ = false; current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } Status 
HloCostAnalysis::HandleGetTupleElement(const HloInstruction*) { // GetTupleElement forwards a pointer and does not touch each element in the // output. + current_should_compute_bottleneck_time_ = false; current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } @@ -166,7 +172,8 @@ Status HloCostAnalysis::HandleReverse(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleSlice(const HloInstruction*) { +Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) { + current_properties_[kBytesAccessedKey] = shape_size_(slice->shape()) * 2; return Status::OK(); } @@ -329,6 +336,7 @@ Status HloCostAnalysis::HandleSelectAndScatter( Status HloCostAnalysis::HandleBitcast(const HloInstruction*) { // A bitcast does no computation and touches no memory. current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } @@ -379,6 +387,10 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleGenerateToken(const HloInstruction*) { + return Status::OK(); +} + Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { auto lhs = convolution->operand(0); auto rhs = convolution->operand(1); @@ -555,11 +567,13 @@ Status HloCostAnalysis::HandleCall(const HloInstruction* call) { } Status HloCostAnalysis::HandleCustomCall(const HloInstruction*) { - // We can't do anything sane with CustomCalls, since we don't know what they - // do, and returning an error status will stop iteration over this - // computation, which is probably also not what we want. So just punt and - // return OK. This will cause all of the properties to be reported as 0, - // which is fine. + // Mark applicable fields as "unknown", since we don't know what CustomCall + // does. 
This is better than returning an error, which would stop iteration, + // and therefore would prevent us from getting *any* stats for a computation + // which contains a CustomCall. + current_properties_[kOptimalSecondsKey] = -1; + current_properties_[kBytesAccessedKey] = -1; + current_properties_[kFlopsKey] = -1; current_should_compute_bottleneck_time_ = false; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index d17678d20f2a23fd98d18b77d5fb25853901a789..0d66736fe1d0677d13a63ede7a203d6ac20c76f5 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -97,6 +97,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleBroadcast(const HloInstruction* broadcast) override; Status HandlePad(const HloInstruction* pad) override; Status HandleReshape(const HloInstruction* reshape) override; + Status HandleGenerateToken(const HloInstruction* token) override; Status HandleTranspose(const HloInstruction* transpose) override; Status HandleWhile(const HloInstruction* xla_while) override; Status HandleConditional(const HloInstruction* conditional) override; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 16fdda8a8b9ade09ea31cda1f4cf5e8ff2c0a081..72adf09c83ee0e419996be663de1df5651e407ed 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -460,5 +460,20 @@ TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { EXPECT_EQ(analysis.flop_count(), 1472); } +TEST_F(HloCostAnalysisTest, Slice) { + // Test the analysis on a slice. 
+ XlaBuilder builder("slice"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); + auto slice = builder.Slice(x, {0}, {1}, {1}); + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 8); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index ed3b654851ab9311ef1fa8278b7acfa987bb294c..0fb65c845a6d4407c81171f6c1569fee98b1d16d 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -162,6 +162,17 @@ StatusOr MakeConcatHlo(ArraySlice operands, HloInstruction::CreateConcatenate(concat_shape, operands, dimension)); } +StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dim_numbers) { + HloComputation* computation = lhs->parent(); + CHECK_EQ(computation, rhs->parent()); + TF_ASSIGN_OR_RETURN( + Shape dot_shape, + ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers)); + return computation->AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers)); +} + StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n) { CHECK_GT(n, 0); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index c9a7361a6af0c2a0839c59a0ea695ec1b9a98bd4..49b1402d689a74874e34423a1832a0b6aa15f469 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -97,6 +97,11 @@ StatusOr MakeGetTupleElementHlo(HloInstruction* operand, StatusOr MakeConcatHlo( tensorflow::gtl::ArraySlice operands, int64 dimension); +// Creates a Dot HLO instruction and adds it to the computation containing `lhs` +// and 
`rhs` (both must be in the same computation). +StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dim_numbers); + // ----------------------------------------------------------------------------- // Some other miscellaneous helpers to generate common HLO patterns. All of // these add all the instructions they generate into the computation containing diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 3b22c93733af293e4d73a2b1b3ac8822dec6d5f5..a0ee8896230d6dcacb5a8eb607fc00ae5226cfa5 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -26,12 +26,14 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" namespace xla { @@ -40,16 +42,16 @@ namespace { // Find and combine identical constants. Constants are identical if they have // the same type and value. -bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { - bool changed = false; - +StatusOr CombineConstants(HloComputation* computation, + bool is_layout_sensitive) { + TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, "")); // Map from ShortDebugString of the layoutless shape of the constant to the // set of constant instructions with that shape. 
Layoutless shape is used to // bin possible common constants together to reduce number of constant // comparisons. If we end up having too many constant comparisons, a more // precise binning might have to be used. std::multimap constants; - + int64 combined = 0; auto inst_it = computation->instructions().begin(); while (inst_it != computation->instructions().end()) { HloInstruction* instruction = *inst_it; @@ -69,7 +71,8 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { auto range = constants.equal_range(shape_string); HloInstruction* match = nullptr; for (auto it = range.first; it != range.second; ++it) { - if (instruction->literal() == it->second->literal()) { + if (instruction->literal() == it->second->literal() && + domain_map->InSameDomain(it->second, instruction)) { match = it->second; break; } @@ -80,12 +83,27 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { // Match found, replace this instruction with the one in the multimap. TF_CHECK_OK(instruction->ReplaceAllUsesWith(match)); TF_CHECK_OK(computation->RemoveInstruction(instruction)); - changed = true; + ++combined; } } } + VLOG(4) << "Combined " << combined << " constants in " << computation->name() + << " computation"; + return combined > 0; +} - return changed; +// An instruction is considered to be equivalent to another only if they +// share the exact same set of operands. +int64 CseHash(const HloInstruction* instruction) { + int64 hash = std::hash()(static_cast(instruction->opcode())); + hash = tensorflow::Hash64Combine( + hash, instruction->opcode() == HloOpcode::kGetTupleElement + ? 
instruction->tuple_index() + : -1); + for (auto operand : instruction->operands()) { + hash = tensorflow::Hash64Combine(hash, operand->unique_id()); + } + return hash; } } // namespace @@ -95,54 +113,53 @@ StatusOr HloCSE::Run(HloModule* module) { const std::function eq_instructions = std::equal_to(); const std::function - eq_computations = std::equal_to(); + eq_computations = [](const HloComputation* lhs, + const HloComputation* rhs) { return *lhs == *rhs; }; + + auto cse_equal = [&](const HloInstruction* lhs, const HloInstruction* rhs) { + return lhs->Identical(*rhs, eq_instructions, eq_computations, + is_layout_sensitive_); + }; + for (auto* computation : module->computations()) { if (only_fusion_computations_ && !computation->IsFusionComputation()) { continue; } - changed |= CombineConstants(computation, is_layout_sensitive_); - - std::list post_order = - computation->MakeInstructionPostOrder(); - std::set removed_instructions; - for (auto instruction : post_order) { - // If the instruction has already been removed by CSE skip over it. - if (removed_instructions.count(instruction) > 0 || - instruction->operand_count() == 0) { + TF_ASSIGN_OR_RETURN(bool combined, + CombineConstants(computation, is_layout_sensitive_)); + changed |= combined; + + // HLO instructions are grouped into equivalency classes by using the + // cse_equal predicate defined above. This set holds a representative + // instruction for each class. + tensorflow::gtl::FlatSet + representatives(/*N=*/computation->instruction_count() + 1, &CseHash, + cse_equal); + for (auto instruction : computation->MakeInstructionPostOrder()) { + // If the instruction has zero operands (constants, parameters, etc.) skip + // over it. + if (instruction->operand_count() == 0) { continue; } - - // Skip instructions which have side effects. - if (instruction->HasSideEffect()) { + // Skip instructions which have side effects or are a domain (which must + // not be CSE-ed). 
+ if (instruction->HasSideEffect() || + instruction->opcode() == HloOpcode::kDomain) { continue; } - // An instruction is considered to be equivalent to another only if they - // share the exact same set of operands. So to find equivalent - // instructions, we just search among instructions which share operand(0) - // of this instruction. - const HloInstruction* operand = instruction->operand(0); - - tensorflow::gtl::InlinedVector - equivalent_instructions; - for (HloInstruction* user : operand->users()) { - if (user != instruction && !user->HasSideEffect() && - user->Identical(*instruction, eq_instructions, eq_computations, - is_layout_sensitive_)) { - equivalent_instructions.push_back(user); - } - } - - // Replace all equivalent instructions with this instruction. - for (HloInstruction* equivalent_instruction : equivalent_instructions) { + auto it = representatives.find(instruction); + if (it != representatives.end()) { + HloInstruction* equivalent_instruction = *it; TF_RETURN_IF_ERROR( - equivalent_instruction->ReplaceAllUsesWith(instruction)); - TF_RETURN_IF_ERROR( - computation->RemoveInstruction(equivalent_instruction)); - removed_instructions.insert(equivalent_instruction); + instruction->ReplaceAllUsesWith(equivalent_instruction)); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); changed = true; + continue; } + representatives.insert(instruction); } } return changed; diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index df8853f34f6a72c52d1cde7332ada3809d2f3d96..16db374566c727f1f3efe2a6d419f1f3caf0aaf1 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/types.h" @@ -72,7 +73,7 @@ TEST_F(HloCseTest, CombineTwoConstants) { auto result = ExecuteAndTransfer(std::move(module), {}); auto expected = Literal::CreateR0(84.0); - LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { @@ -104,7 +105,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { auto result = ExecuteAndTransfer(std::move(module), {}); auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); - LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { @@ -134,38 +135,53 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { auto result = ExecuteAndTransfer(std::move(module), {}); auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); - LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, ConstantsSameValueDifferentType) { // Test that constants with the same value but different type are *not* // commoned. 
auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + std::vector constants; + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f)))); // Duplicate the float constant to verify something happens. 
- builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f)))); + + const Shape shape_r0 = ShapeUtil::MakeShape(F32, {}); + for (int64 i = 0; i < constants.size(); ++i) { + constants[i] = builder.AddInstruction( + HloInstruction::CreateConvert(shape_r0, constants[i])); + } + HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary( + shape_r0, HloOpcode::kAdd, constants[0], constants[1])); + for (int64 i = 2; i < constants.size(); ++i) { + root = builder.AddInstruction(HloInstruction::CreateBinary( + shape_r0, HloOpcode::kAdd, root, constants[i])); + } auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(7, computation->instruction_count()); + EXPECT_EQ(20, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); - EXPECT_EQ(6, computation->instruction_count()); + // CSE will remove both the second float(42.0f) and the corresponding + // convert/cast. 
+ EXPECT_EQ(18, computation->instruction_count()); } TEST_F(HloCseTest, NonscalarConstants) { @@ -469,5 +485,56 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { EXPECT_THAT(root, op::Add(op::Map(op::Constant()), op::Map(op::Constant()))); } +TEST_F(HloCseTest, CompareComputations) { + auto module = ParseHloString(R"( + HloModule m + + add_computation { + add_lhs = f32[] parameter(0) + add_rhs = f32[] parameter(1) + ROOT add_root = f32[] add(add_lhs, add_rhs) + } + + add_computation2 { + add_lhs2 = f32[] parameter(0) + add_rhs2 = f32[] parameter(1) + ROOT add_root2 = f32[] add(add_lhs2, add_rhs2) + } + + ENTRY entry { + p = f32[10]{0} parameter(0) + c = f32[] constant(0) + r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation + r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2 + ROOT f2 = (f32[],f32[]) tuple(r1, r2) + })") + .ValueOrDie(); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->operand(0), root->operand(1)); +} + +TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { + // Test that constants with the same value but in different domains (disjoint + // in this case) are not collapsed. 
+ auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(2, computation->instruction_count()); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(2, computation->instruction_count()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 0c37a8d75f38dabaad886cc9d4adce8ab29ddf18..d0200058683b2db8f5f0469d6c643014881f741e 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -363,7 +363,7 @@ bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) { bool HloDataflowAnalysis::UpdateConditionalValueSet( HloInstruction* conditional) { CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); - std::vector inputs = { + const InstructionValueSet* const inputs[] = { &GetInstructionValueSet( conditional->true_computation()->root_instruction()), &GetInstructionValueSet( @@ -538,7 +538,7 @@ bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) { bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) { CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile); - std::vector inputs = { + const InstructionValueSet* const inputs[] = { &GetInstructionValueSet(xla_while->while_body()->root_instruction()), &GetInstructionValueSet(xla_while->operand(0))}; if (ssa_form_) { @@ -878,4 +878,135 @@ Status HloDataflowAnalysis::Verify() const { return Status::OK(); } +bool HloDataflowAnalysis::DoesNotUseOperandBuffer( + const HloInstruction* operand, const ShapeIndex& index, + const HloInstruction* user) const { + 
+ CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + if (user->opcode() == HloOpcode::kFusion && + user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + // Find fusion parameter associated with 'operand'. + HloInstruction* fusion_param = + user->fused_parameter(user->operand_index(operand)); + // Iterate through all users of all uses of the fusion parameter value. + // Return false if any uses are detected, return true otherwise. + const HloValue& value = GetValueDefinedAt(fusion_param, index); + return value.uses().empty(); + } else { + // Return false if any value at 'operand' and 'index' is used at 'user'. + for (const HloValue* value : GetValueSet(operand, index).values()) { + for (const HloUse& use : value->uses()) { + if (use.instruction == user) { + return false; + } + } + } + } + + return true; +} + +bool HloDataflowAnalysis::CanShareOperandBufferWithUser( + HloInstruction* operand, const ShapeIndex& operand_index, + HloInstruction* user, const ShapeIndex& user_index) const { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + const Shape& operand_subshape = + ShapeUtil::GetSubshape(operand->shape(), operand_index); + const Shape& user_subshape = + ShapeUtil::GetSubshape(user->shape(), user_index); + // Check that operand and user emit the same shape and layout.
+ if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { + return false; + } + + if (user->opcode() == HloOpcode::kFusion) { + // Get the parameter associated with 'operand'; + HloInstruction* fusion_param = + user->fused_parameter(user->operand_index(operand)); + + const HloValue& value = GetValueDefinedAt(fusion_param, operand_index); + if (value.uses().size() != 1) { + return false; + } + const HloUse& use = value.uses()[0]; + + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + if (user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. + return use.instruction == user->fused_expression_root() && + use.operand_number == 0; + } + } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && + user->fused_expression_root()->opcode() == HloOpcode::kAdd) { + // Output fusion with kAdd fused root. + + // Check if one operand of kAdd fused root is kDot or kConvolution. + auto* add = user->fused_expression_root(); + auto add_operand_it = + std::find_if(add->operands().begin(), add->operands().end(), + [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot; + }); + if (add_operand_it == add->operands().end()) { + return false; + } + auto* matched_add_operand = *add_operand_it; + // Calculate operand index of 'add' operand which was not matched above. + const int64 other_add_operand_index = + matched_add_operand == add->operand(0) ? 1 : 0; + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root (at operand + // index 'other_add_operand_index'). 
+ return use.instruction == user->fused_expression_root() && + use.operand_number == other_add_operand_index; + } + } + + if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kWhile) { + // We eliminated other users in BufferLiveness::live_range_strictly_before, + // so here we just need to check that the use is at operand index 0. + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && operand_indices[0] == 0; + } + if (user->opcode() == HloOpcode::kCall) { + // Get all uses of value defined by 'operand' at 'operand_index'. + const auto& uses = GetValueDefinedAt(operand, operand_index).uses(); + // Return true iff: + // *) There exist exactly two uses of 'operand'. + // *) One use is by 'user' (caller). + // *) One use is by root instruction of called computation (callee root). + // (Note: we check the root of the called computation, because the + // root result buffer is required to alias with the Call result buffer). + // *) The root instruction of the called computation is element-wise on + // 'operand'. + const bool found_caller_use = + std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) { + return use.instruction == user; + }) != uses.end(); + auto* callee_root = user->to_apply()->root_instruction(); + const bool found_elementwise_callee_use = + std::find_if( + uses.begin(), uses.end(), [callee_root](const HloUse& use) { + return use.instruction == callee_root && + callee_root->IsElementwiseOnOperand(use.operand_number); + }) != uses.end(); + return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; + } + + // Loop fusions that contain transposing copies won't reach here as they have + // different layouts, which fails the check in the beginning of this function. + // + // Multi-output fusion will fail the check here as tuples are not considered + // an elementwise operation.
+ return user->IsElementwiseOnOperand(user->operand_index(operand)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 7b8a74b096ff48733717e78ada5bb56a28caed72..9868746b6113881949e388cd2a4aa9f610b1fdb7 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -118,6 +118,23 @@ class HloDataflowAnalysis { string ToString() const; + // Returns true if 'user' cannot possibly use the buffer at 'index' in + // 'operand'. Returns false otherwise. + // + // REQUIRES: 'operand' is an operand of 'user'. + bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user) const; + + // Returns true if 'user' (at 'user_index') can share a buffer with its + // operand 'operand' (at 'operand_index'). Returns false otherwise. + // + // REQUIRES: 'operand' is an operand of 'user'. 
+ bool CanShareOperandBufferWithUser(HloInstruction* operand, + const ShapeIndex& operand_index, + HloInstruction* user, + const ShapeIndex& user_index) const; + protected: HloDataflowAnalysis(const HloModule& module, bool ssa_form, bool bitcast_defines_value = false); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 07f69b8e1339fed636e4eb54791941b85e09fd17..db1822ec47a7f52e2c3ef8dcbf433cd787ef75ab 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1873,5 +1873,469 @@ INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, HloDataflowAnalysisTest, ::testing::Values(false, true)); +class HloDataflowAnalysisTestBase : public HloTestBase { + protected: + void BuildModule(std::unique_ptr computation) { + module_ = CreateNewModule(); + computation_ = module_->AddEntryComputation(std::move(computation)); + } + + void RunAnalysis() { + CHECK_NOTNULL(module_.get()); + dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie(); + } + + void BuildModuleAndRunAnalysis(std::unique_ptr computation) { + BuildModule(std::move(computation)); + RunAnalysis(); + } + + std::unique_ptr module_; + HloComputation* computation_ = nullptr; + std::unique_ptr dataflow_analysis_; +}; + +class DoesNotUseOperandBufferTest : public HloDataflowAnalysisTestBase {}; + +TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) { + auto builder = HloComputation::Builder(TestName()); + + Shape elem_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1)); + 
builder.AddInstruction( + HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // GetTupleElement instructions only access the top-level buffer of their + // operand. + EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, gte0)); + EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, gte1)); + EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte0)); + EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte1)); +} + +TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction never uses tuple element 0, but does use element 1. 
+ EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion)); + EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion)); +} + +class CanShareOperandBufferWithUserTest : public HloDataflowAnalysisTestBase {}; + +TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto log = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {})); + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, + NonElementwiseLoopFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "param0")); + + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, param0)); + + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(data_shape, neg, {0, 1})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {reverse, neg}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, + MultiOutputFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + Shape in_shape = ShapeUtil::MakeShape(F32, {8}); + Shape out_shape = 
ShapeUtil::MakeShape(PRED, {8}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, in_shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, in_shape, "param1")); + + auto copy0 = builder.AddInstruction( + HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param0)); + auto copy1 = builder.AddInstruction( + HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param1)); + + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({copy1, copy0})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {tuple, copy1, copy0}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {0})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {1})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {0})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {1})); +} + +TEST_F(CanShareOperandBufferWithUserTest, + ElementwiseLoopFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, operand)); + + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(data_shape, HloOpcode::kExp, neg)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {exp, neg}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, 
ElementWiseDifferentShape) { + auto builder = HloComputation::Builder(TestName()); + + Shape in_shape = ShapeUtil::MakeShape(F32, {8}); + Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, in_shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, in_shape, "param1")); + auto result = builder.AddInstruction( + HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + result, {})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + result, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {})); + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, copy, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice 
instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction can share with tuple element 1. + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(tuple, {0}, + fusion, {})); + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(tuple, {1}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, + FusedDynamicUpdateSliceWithConvertCantShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape data_shape_bf16 = ShapeUtil::MakeShape(BF16, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + auto convert1 = builder.AddInstruction( + HloInstruction::CreateConvert(data_shape_bf16, gte1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. 
+ auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape_bf16, convert1, update, starts)); + + auto convert2 = builder.AddInstruction( + HloInstruction::CreateConvert(data_shape, dynamic_update_slice)); + builder.AddInstruction(HloInstruction::CreateTuple({gte0, convert2})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {convert2, dynamic_update_slice, starts, update, convert1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction can't share with tuple element 1. + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(gte1, {}, fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape update_shape = ShapeUtil::MakeShape(F32, {4}); + Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto update = builder.AddInstruction( + HloInstruction::CreateParameter(1, update_shape, "update")); + auto starts = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "starts")); + auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, data, update, starts)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The DynamicUpdateSlice instruction can share with the data operand, but not + // with update or starts. 
+ EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(data, {}, dus, {})); + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {})); + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto a = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + auto b = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction( + HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto add_operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kAdd, dot, add_operand)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, dot}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused dot add should be able to share buffer with 'add_operand'. 
+ EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(add_operand, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(data_shape, operand, {0, 1})); + + auto two = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, two, reverse}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused operand->reverse->add cannot alias operand buffer 'operand'. 
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + + auto make_cond = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Cond"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); + return builder.Build(); + }; + + auto make_body = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Body"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data)); + return builder.Build(); + }; + + module_ = CreateNewModule(); + HloComputation* cond_computation = + module_->AddEmbeddedComputation(make_cond()); + HloComputation* body_computation = + module_->AddEmbeddedComputation(make_body()); + + auto builder = HloComputation::Builder(TestName()); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto whil = builder.AddInstruction(HloInstruction::CreateWhile( + data_shape, cond_computation, body_computation, data)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + // The While instruction can share with the data operand. + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(data, {}, whil, {})); +} + +// Tests that Call can alias operand buffer if the only use of the operand +// in the called computation is an elementwise instruction. +TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { + Shape shape = ShapeUtil::MakeShape(F32, {8}); + // Build sub-computation with fusion root. 
+ auto sub_builder = HloComputation::Builder(TestName() + "_sub"); + auto sub_param = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "sub_param")); + auto one = sub_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto ones = sub_builder.AddInstruction( + HloInstruction::CreateBroadcast(shape, one, {1})); + auto add = sub_builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); + + module_ = CreateNewModule(); + auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); + sub_computation->CreateFusionInstruction({add, ones}, + HloInstruction::FusionKind::kLoop); + + // Build entry-computation with kCall which calls 'sub_computation'. + auto builder = HloComputation::Builder(TestName()); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto reverse = + builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0})); + auto call = builder.AddInstruction( + HloInstruction::CreateCall(shape, {reverse}, sub_computation)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(reverse, {}, call, {})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc new file mode 100644 index 0000000000000000000000000000000000000000..78955db0da02f16eb93689db947dc1190ab7049a --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +class HloDomainIsolator::RunContext { + public: + RunContext(HloModule* module, HloDomainIsolator* isolator) + : module_(module), isolator_(isolator) {} + + StatusOr Run(); + + private: + // Inserts a kDomain instruction between parent and operand, in case + // the attribute (ie, sharding) values change between instruction and operand. + // Returns the newly inserted kDomain instruction, or nullptr if no kDomain + // instruction was necessary. 
+ StatusOr CreateDomain(HloInstruction* instruction, + HloInstruction* parent, + HloInstruction* operand); + + HloModule* module_; + HloDomainIsolator* isolator_; +}; + +StatusOr HloDomainIsolator::RunContext::CreateDomain( + HloInstruction* instruction, HloInstruction* parent, + HloInstruction* operand) { + HloInstruction* domain = nullptr; + std::unique_ptr domain_instruction = + isolator_->creator_(instruction, operand); + if (domain_instruction != nullptr) { + domain = operand->parent()->AddInstruction(std::move(domain_instruction)); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(parent, domain)); + } + return domain; +} + +StatusOr HloDomainIsolator::RunContext::Run() { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Isolator"); + + int64 added_domains = 0; + for (HloComputation* computation : module_->computations()) { + // Walk in post order and place all the required kDomain instructions. + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kDomain) { + continue; + } + for (HloInstruction* operand : instruction->unique_operands()) { + // When applying multiple domains, we could end up stacking more than + // one in one edge, so here we want to build the effective + // (kDomain-less) instruction->operand edge. + HloInstruction* parent = instruction; + while (operand->opcode() == HloOpcode::kDomain) { + parent = operand; + operand = operand->mutable_operand(0); + } + // Check whether a kDomain is necessary between instruction and operand. 
+ TF_ASSIGN_OR_RETURN(HloInstruction * domain, + CreateDomain(instruction, parent, operand)); + if (domain != nullptr) { + VLOG(4) << "New domain: " << domain->ToString(); + ++added_domains; + } + } + } + } + VLOG(3) << "Added " << added_domains << " kDomain instructions"; + if (added_domains > 0) { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "After Domain Isolator"); + } + return added_domains > 0; +} + +HloDomainIsolator::HloDomainIsolator(DomainCreator creator) + : creator_(std::move(creator)) {} + +StatusOr HloDomainIsolator::Run(HloModule* module) { + RunContext run_context(module, this); + return run_context.Run(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h new file mode 100644 index 0000000000000000000000000000000000000000..e0c5718509dabebb7b9307bf764b0ea1ce7369a0 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Domain isolation is the task of placing kDomain instructions between HLO +// instructions having different sharding. A kDomain instruction is essentially +// used to break an HLO graph edge connecting two instructions with different +// sharding. If a set of connected instructions all have the same sharding, no +// kDomain instruction will be placed. +class HloDomainIsolator : public HloPassInterface { + public: + // Creates a new kDomain instruction for the edge between the use instruction + // (the first HloInstruction argument), and the operand instruction (the + // second HloInstruction argument). + // Returns nullptr in case no domain separation is necessary. + using DomainCreator = std::function( + HloInstruction*, HloInstruction*)>; + + explicit HloDomainIsolator(DomainCreator creator); + + tensorflow::StringPiece name() const override { return "domain_isolator"; } + + StatusOr Run(HloModule* module) override; + + private: + class RunContext; + + DomainCreator creator_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc new file mode 100644 index 0000000000000000000000000000000000000000..ebd5adb5d573ce4b556046f85eb26a6ad59efcb9 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -0,0 +1,176 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" + +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +/* static */ StatusOr> HloDomainMap::Create( + HloComputation* computation, string domain_kind) { + auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + TF_RETURN_IF_ERROR(domain_map->Populate(computation)); + return std::move(domain_map); +} + +/* static */ StatusOr> HloDomainMap::Create( + HloModule* module, string domain_kind) { + auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + for (HloComputation* computation : module->computations()) { + TF_RETURN_IF_ERROR(domain_map->Populate(computation)); + } + return std::move(domain_map); +} + +bool HloDomainMap::InSameDomain(HloInstruction* instruction1, + HloInstruction* instruction2) const { + int64 domain_id1 = FindOrDefault(instruction_to_domain_, instruction1, -1); + int64 domain_id2 = FindOrDefault(instruction_to_domain_, instruction2, -1); + return domain_id1 >= 0 && domain_id1 == domain_id2; +} + +Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { + TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain); + // We only check operands, so we are sure to not process the empty domain from + // both 
sides. + for (HloInstruction* operand : instruction->unique_operands()) { + if (IsDomainInstruction(operand)) { + auto domain = MakeUnique(); + domain->enter_domains.insert(operand); + domain->exit_domains.insert(instruction); + TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); + } + } + return Status::OK(); +} + +Status HloDomainMap::Populate(HloComputation* computation) { + for (HloInstruction* instruction : computation->instructions()) { + if (IsDomainInstruction(instruction)) { + // If this is a kDomain of the kind we are currently processing, check + // whether this is an "empty domain". + TF_RETURN_IF_ERROR(TryProcessEmptyDomain(instruction)); + continue; + } + int64 domain_id = FindOrDefault(instruction_to_domain_, instruction, -1); + if (domain_id >= 0) { + // We have already processed this instruction. + continue; + } + TF_ASSIGN_OR_RETURN(std::unique_ptr domain, + CreateDomain(instruction)); + TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); + } + return Status::OK(); +} + +Status HloDomainMap::InsertDomain( + std::unique_ptr domain) { + int64 domain_id = instruction_domains_.size(); + instruction_domains_.push_back(std::move(domain)); + for (HloInstruction* instruction : instruction_domains_.back()->reach_set) { + instruction_to_domain_[instruction] = domain_id; + } + return Status::OK(); +} + +Status HloDomainMap::ExpandDomain(HloInstruction* instruction, + DomainMetadata::Domain* domain) const { + std::vector in_queue; + in_queue.push_back(instruction); + while (!in_queue.empty()) { + HloInstruction* current_instruction = in_queue.back(); + in_queue.pop_back(); + if (domain->reach_set.insert(current_instruction).second) { + // We should not be finding instructions with assigned domain here. + // If we assigned a domain to the instruction, it means that all the + // instructions reached by it, should have a domain as well. 
+ int64 domain_id = + FindOrDefault(instruction_to_domain_, current_instruction, -1); + TF_RET_CHECK(domain_id < 0) + << "Instruction " << current_instruction->ToString() + << " already has domain " << domain_id; + for (HloInstruction* operand : current_instruction->operands()) { + if (IsDomainInstruction(operand)) { + // The reach set instruction is a user of the domain instruction + // (the instruction sees the kDomain as operand). + // IOW the dataflow enters the domain through the kDomain instruction. + domain->enter_domains.insert(operand); + } else { + in_queue.push_back(operand); + } + } + for (HloInstruction* user : current_instruction->users()) { + if (IsDomainInstruction(user)) { + // The reach set instruction is an operand of the domain instruction + // (the instruction sees the kDomain as user). + // IOW the dataflow exits the domain through the kDomain instruction. + domain->exit_domains.insert(user); + } else { + in_queue.push_back(user); + } + } + } + } + return Status::OK(); +} + +StatusOr> HloDomainMap::CreateDomain( + HloInstruction* instruction) const { + auto domain = MakeUnique(); + TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get())); + domain->instructions = MakeNonDomainInstructions(domain->reach_set); + return std::move(domain); +} + +bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const { + if (instruction->opcode() != HloOpcode::kDomain) { + return false; + } + if (!domain_kind_.empty()) { + if (instruction->user_side_metadata().Kind() != domain_kind_) { + return false; + } + // Both user and operand side of the metadata must be of the same kind. 
+ CHECK(instruction->operand_side_metadata().Kind() == domain_kind_) + << "Instruction " << instruction->ToString() + << " has mismatching metadata kinds"; + } + return true; +} + +/* static */ std::vector +HloDomainMap::MakeNonDomainInstructions( + const tensorflow::gtl::FlatSet& instruction_set) { + std::vector instructions; + instructions.reserve(instruction_set.size()); + for (HloInstruction* instruction : instruction_set) { + if (instruction->opcode() != HloOpcode::kDomain) { + instructions.push_back(instruction); + } + } + std::sort(instructions.begin(), instructions.end(), + [](HloInstruction* a, HloInstruction* b) { + return a->unique_id() < b->unique_id(); + }); + return instructions; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h new file mode 100644 index 0000000000000000000000000000000000000000..e62ef763fb3881ab6030b1f6a66266ac80a3d84d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { + +// The HloDomainMap splits a set of instructions within a module or computation, +// into different domains, separated by kDomain instructions. +// A domain is composed of a set of instructions which can reach each other via +// operand/user edges, without crossing a kDomain instruction of a given kind. +// A domain never crosses computation boundaries. +class HloDomainMap { + public: + // Creates a new HloDomainMap, creating all the domains within the input + // computation, of the given kind. If domain_kind is not empty, only the + // kDomain instructions of domain_kind will be considered as separators. + // Otherwise every kDomain instruction will be splitting domains. + static StatusOr> Create( + HloComputation* computation, string domain_kind); + + // Creates a new HloDomainMap, creating all the domains within the input + // module, of the given kind. If domain_kind is not empty, only the + // kDomain instructions of domain_kind will be considered as separators. + // Otherwise every kDomain instruction will be splitting domains. + static StatusOr> Create(HloModule* module, + string domain_kind); + + // Retrieves all the domains the input module or computation is composed of.
+ const std::vector>& GetDomains() + const { + return instruction_domains_; + } + + // Checks whether two instructions are within the same domain. + bool InSameDomain(HloInstruction* instruction1, + HloInstruction* instruction2) const; + + // Checks whether instruction is a kDomain instruction of the kind we are + // currently processing. + bool IsDomainInstruction(HloInstruction* instruction) const; + + private: + HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {} + + // Check if the kDomain instruction is facing (via its operand link) another + // kDomain instruction of the same kind, hence defining an empty domain. + // If that is the case, create the empty domain and call the proper + // normalizer. + Status TryProcessEmptyDomain(HloInstruction* instruction); + + Status Populate(HloComputation* computation); + + // Inserts the provided domain into the ones tracked by this object, + // creating a new domain ID. + Status InsertDomain(std::unique_ptr domain); + + // Starting from the given instruction, expands, operand- and user-wise, the + // set of instructions which can be reached without crossing a kDomain + // instruction of the kind specified by domain_kind_. + // The domain data structure will be populated with all the reached + // instructions, and the boundaries of the domain, with the kDomain + // instructions encountered while expanding the reach. + Status ExpandDomain(HloInstruction* instruction, + DomainMetadata::Domain* domain) const; + + // Creates a domain data structure using the ExpandDomain() API. + StatusOr> CreateDomain( + HloInstruction* instruction) const; + + // Out of an instruction set, returns a vector of all the ones which are not + // a kDomain kind.
+ static std::vector MakeNonDomainInstructions( + const tensorflow::gtl::FlatSet& instruction_set); + + string domain_kind_; + std::vector> instruction_domains_; + tensorflow::gtl::FlatMap instruction_to_domain_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h new file mode 100644 index 0000000000000000000000000000000000000000..aa0308100a21f109579de75788fce7d242d6a6b0 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { + +// Cannot include hlo_instruction.h as this file is included from there. +class HloInstruction; + +// The DomainMetadata represents the base class for metadata which can be +// attached to kDomain HLO instructions. 
+class DomainMetadata { + public: + // A Domain data structure captures all the information about a kDomain + // bounded instruction set. + struct Domain { + // The set of instructions which are reachable from each other via + // operand/user pathways, without crossing a kDomain instruction of a given + // kind. The reach_set can contain kDomain instructions of other kinds, if + // two domains of different kind intersect each other. + tensorflow::gtl::FlatSet reach_set; + + // The same instructions in reach_set, but purged of kDomain instructions. + std::vector instructions; + + // If we consider a graph edge as an arrow oriented from the operand to the + // user, the enter_domains will contain the set of kDomain instructions + // whose dataflow enters the reach set (domain), while the exit_domains + // contains the set of kDomain instructions whose dataflow exits the reach + // set. + tensorflow::gtl::FlatSet enter_domains; + tensorflow::gtl::FlatSet exit_domains; + }; + + virtual ~DomainMetadata() = default; + + // Clones the metadata object. + virtual std::unique_ptr Clone() const = 0; + + // Returns the metadata type. A unique identifier which describes the real + // metadata type. + virtual tensorflow::StringPiece Kind() const = 0; + + // Compares the metadata object with another one and returns true if the + // two match. + virtual bool Matches(const DomainMetadata& other) const = 0; + + // Returns a string representation of the metadata. + virtual string ToString() const = 0; + + // Given a reachable set (the set of instructions which are reachable from + // each other via user/operand pathways, without crossing a kDomain + // instruction), makes sure that all of them have metadata attributes which + // are coherent with this metadata object.
+ virtual Status NormalizeInstructions(const Domain& domain) const = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.cc b/tensorflow/compiler/xla/service/hlo_domain_remover.cc new file mode 100644 index 0000000000000000000000000000000000000000..1d06040b0e7c92b03f4cb5481bdee73a0f74f939 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.cc @@ -0,0 +1,149 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_domain_remover.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +class HloDomainRemover::RunContext { + public: + RunContext(HloModule* module, HloDomainRemover* remover) + : module_(module), remover_(remover) {} + + StatusOr Run(); + + private: + // Verifies the consistency of the domain, and normalizes the instructions + // within it. 
+ Status VerifyAndNormalizeDomain(const DomainMetadata::Domain& domain); + + HloModule* module_; + HloDomainRemover* remover_; +}; + +Status HloDomainRemover::RunContext::VerifyAndNormalizeDomain( + const DomainMetadata::Domain& domain) { + // Verify that the whole kDomain frontier bounding the instruction reach set + // has matching metadata. + // A kDomain instruction has two sides of metadata, a user facing and an + // operand facing. + // A reachable instruction set can make contact with a kDomain instruction on + // a user facing side (the kDomain is operand of the instruction), or on an + // operand facing side (the kDomain is user of the instruction). + // And depending on the contact side, the proper metadata object + // (user_side_metadata() vs. operand_side_metadata()) needs to be used for + // consistency checks. + const DomainMetadata* ref_metadata = nullptr; + VLOG(4) << "Reach set:"; + for (HloInstruction* instruction : domain.instructions) { + VLOG(4) << " " << instruction->name(); + } + VLOG(4) << " Domains:"; + for (HloInstruction* instruction : domain.enter_domains) { + const DomainMetadata& meta = instruction->user_side_metadata(); + VLOG(4) << " User side: " << instruction->name(); + VLOG(4) << " " << meta.ToString(); + if (ref_metadata == nullptr) { + ref_metadata = &meta; + } else { + TF_RET_CHECK(meta.Matches(*ref_metadata)) + << "Metadata mismatch at instruction " << instruction->name() << " : " + << meta.ToString() << " vs " << ref_metadata->ToString(); + } + } + for (HloInstruction* instruction : domain.exit_domains) { + const DomainMetadata& meta = instruction->operand_side_metadata(); + VLOG(4) << " Operand side: " << instruction->name(); + VLOG(4) << " " << meta.ToString(); + if (ref_metadata == nullptr) { + ref_metadata = &meta; + } else { + TF_RET_CHECK(meta.Matches(*ref_metadata)) + << "Metadata mismatch at instruction " << instruction->name() << " : " + << meta.ToString() << " vs " << ref_metadata->ToString(); + } + } + if
(ref_metadata != nullptr) { + VLOG(4) << "Applying domain normalization: " << ref_metadata->ToString(); + TF_RETURN_IF_ERROR(ref_metadata->NormalizeInstructions(domain)); + } else { + // No kDomain instruction was present within this domain, so call the + // generic normalization functions and have them apply their heuristic. + VLOG(2) << "Applying domain-less normalization"; + TF_RETURN_IF_ERROR(remover_->normalizer_(domain)); + } + return Status::OK(); +} + +StatusOr HloDomainRemover::RunContext::Run() { + VLOG(4) << "Processing metadata domain: '" << remover_->kind_ << "'"; + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Remover"); + + int64 removed_domains = 0; + for (HloComputation* computation : module_->computations()) { + // First create the domain instruction sets. A domain instruction set is + // the set of instructions whose edges never cross a kDomain instruction. + TF_ASSIGN_OR_RETURN(std::unique_ptr domain_map, + HloDomainMap::Create(computation, remover_->kind_)); + // Verify and normalize every domain populated within the map. + for (auto& domain : domain_map->GetDomains()) { + TF_RETURN_IF_ERROR(VerifyAndNormalizeDomain(*domain)); + } + + // Now remove all the kDomain instructions of the kind specified by the + // remover, that are within the currently processed computation from the + // graph.
+ for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + for (HloInstruction* operand : instruction->unique_operands()) { + if (domain_map->IsDomainInstruction(operand)) { + VLOG(5) << "Removing " << operand->name(); + TF_RETURN_IF_ERROR( + operand->ReplaceAllUsesWith(operand->mutable_operand(0))); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(operand)); + ++removed_domains; + } + } + } + HloInstruction* root = computation->root_instruction(); + if (root != nullptr && domain_map->IsDomainInstruction(root)) { + VLOG(5) << "Removing " << root->name(); + computation->set_root_instruction(root->mutable_operand(0)); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(root)); + ++removed_domains; + } + } + VLOG(3) << "Removed " << removed_domains << " kDomain instructions of '" + << remover_->kind_ << "' kind"; + if (removed_domains > 0) { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "After Domain Remover"); + } + return removed_domains > 0; +} + +StatusOr HloDomainRemover::Run(HloModule* module) { + RunContext run_context(module, this); + return run_context.Run(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h new file mode 100644 index 0000000000000000000000000000000000000000..0c71dd34fd4d2944037dc965a2c9ad2c592d6e3e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_ + +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { + +// Removes all the kDomain instructions of a given kind from the input module, +// and calls the normalizer to propagate the properties on the possibly new born +// instructions. +class HloDomainRemover : public HloPassInterface { + public: + // Creates a new HloDomainRemover object tasked at removing all the kDomain + // instructions of a given kind. + // In case a reachable set (the set of instructions within a computation, + // which are mutually reachable via operand/user pathways) has all the + // instructions in it with the same attributes (ie, sharding), a normalizer + // function is tasked at applying attribute normalization on the instructions + // within such domain. 
+ HloDomainRemover( + tensorflow::StringPiece kind, + std::function normalizer) + : kind_(kind.ToString()), normalizer_(std::move(normalizer)) {} + + tensorflow::StringPiece name() const override { return "domain_remover"; } + + StatusOr Run(HloModule* module) override; + + private: + class RunContext; + + string kind_; + std::function normalizer_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5553ddb153f7f1f2e6a790890c11f35e192488c4 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -0,0 +1,432 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_domain_remover.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class HloDomainTest : public HloTestBase { + protected: + bool FindUserViaDomainPath(HloInstruction* instruction, + HloInstruction* operand) const { + for (HloInstruction* user : operand->users()) { + if (user == instruction) { + return true; + } + if (user->opcode() == HloOpcode::kDomain && + FindUserViaDomainPath(instruction, user)) { + return true; + } + } + return false; + } + + // Checks whether there is a kDomain instruction in the edge between the + // instruction and the operand. + bool HasDomainEdge(HloModule* module, + tensorflow::StringPiece instruction_name, + tensorflow::StringPiece operand_name) { + HloInstruction* instruction = FindInstruction(module, instruction_name); + HloInstruction* operand = FindInstruction(module, operand_name); + CHECK_NE(instruction, nullptr); + CHECK_NE(operand, nullptr); + if (!instruction->IsUserOf(operand)) { + // If instruction is not an immediate user, we must find a path from + // operand to instruction anyway, otherwise there is a corruption. 
+ if (FindUserViaDomainPath(instruction, operand)) { + return true; + } + LOG(FATAL) << "Bad HLO module generated across the '" << instruction_name + << "' and '" << operand_name << "' instructions:\n" + << module->ToString(); + } + return false; + } + + StatusOr> ParseModule( + tensorflow::StringPiece hlo_string) { + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + return ParseHloString(hlo_string, config); + } +}; + +// Dummy DomainMetadata implementation which creates kDomain boundaries around +// HLO instructions with the same metadata().op_name() values. +class OpNameMetadata : public DomainMetadata { + public: + explicit OpNameMetadata(string opname) : opname_(std::move(opname)) {} + + std::unique_ptr Clone() const override { + return MakeUnique(opname_); + } + + tensorflow::StringPiece Kind() const override { return KindName(); } + + bool Matches(const DomainMetadata& other) const override { + const OpNameMetadata* other_ptr = + dynamic_cast(&other); + if (other_ptr == nullptr) { + // If other is not an OpNameMetadata, then it is clearly not a match. + return false; + } + return opname_ == other_ptr->opname_; + } + + string ToString() const override { return opname_; } + + Status NormalizeInstructions( + const DomainMetadata::Domain& domain) const override { + // For the purposes of this test, nothing to do. + return Status::OK(); + } + + static tensorflow::StringPiece KindName() { return "opname"; } + + private: + string opname_; +}; + +// Creator function for OpNameMetadata domains.
+std::unique_ptr OpNameDomainCreator(HloInstruction* instruction, + HloInstruction* operand) { + if (instruction->metadata().op_name() == operand->metadata().op_name()) { + return nullptr; + } + std::unique_ptr operand_side_metadata = + MakeUnique(operand->metadata().op_name()); + std::unique_ptr user_side_metadata = + MakeUnique(instruction->metadata().op_name()); + return HloInstruction::CreateDomain(operand->shape(), operand, + std::move(operand_side_metadata), + std::move(user_side_metadata)); +} + +Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain) { + // Nothing to do for the particular use this test makes of the OpName domains. + return Status::OK(); +} + +TEST_F(HloDomainTest, CheckDomainLinks) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4], f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = f32[4] get-tuple-element(p0), index=1 + c = f32[4] add(f32[4] a, f32[4] b), sharding={maximal device=1} + d = f32[4] subtract(a, b), sharding={maximal device=1} + e = f32[4] multiply(c, d), sharding={maximal device=1} + ROOT f = (f32[4], f32[4], f32[4]) tuple(c, d, e) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + EXPECT_TRUE(remover_changed); + + 
EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); +} + +TEST_F(HloDomainTest, CheckNoDomainAddedIfNoSharding) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4], f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = f32[4] get-tuple-element(p0), index=1 + c = f32[4] add(f32[4] a, f32[4] b) + d = f32[4] subtract(a, b) + e = f32[4] multiply(c, d) + ROOT f = (f32[4], f32[4], f32[4]) tuple(c, d, e) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_TRUE(!isolator_changed); +} + +TEST_F(HloDomainTest, CheckDomainAroundIO) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = (f32[4], u32[]) send(a), channel_id=1, sharding={maximal device=0} + c = () send-done(b), channel_id=1, sharding={maximal device=0} + d = (f32[4], u32[]) recv(), channel_id=2, sharding={maximal device=0} + e = f32[4] recv-done(d), channel_id=2, sharding={maximal device=0} + f = f32[4] add(a, e) + g = f32[4] subtract(a, e) + ROOT h = (f32[4], f32[4]) tuple(f, g) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module.get(), "b", "a")); + 
EXPECT_TRUE(HasDomainEdge(module.get(), "f", "e")); + EXPECT_FALSE(HasDomainEdge(module.get(), "a", "p0")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + EXPECT_TRUE(remover_changed); + + EXPECT_FALSE(HasDomainEdge(module.get(), "b", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "f", "e")); +} + +TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=-1} + b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=-1} + c = f32[4] add(b, b), sharding={maximal device=-1} + d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=-1} + ROOT e = () send-done(d), channel_id=2, sharding={maximal device=-1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_FALSE(isolator_changed); +} + +TEST_F(HloDomainTest, CheckNormalizationOnPureIOComputation) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=0} + b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=0} + c = f32[4] add(b, b) + d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=0} + ROOT e = () send-done(d), channel_id=2, sharding={maximal device=0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainRemover remover(ShardingMetadata::KindName(), + 
NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + EXPECT_FALSE(remover_changed); + + HloInstruction* add = FindInstruction(module.get(), "c"); + ASSERT_NE(add, nullptr); + auto device = add->sharding_unique_device(); + EXPECT_TRUE(device.has_value()); + EXPECT_EQ(*device, 0); +} + +TEST_F(HloDomainTest, CheckMultiDomainLinks) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4], f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = f32[4] get-tuple-element(p0), index=1 + c = f32[4] add(a, b), sharding={maximal device=1} + d = f32[4] subtract(a, c), sharding={maximal device=1}, metadata={op_name="D"} + e = f32[4] multiply(c, d), sharding={maximal device=1}, metadata={op_name="D"} + f = f32[4] add(e, c), sharding={maximal device=1} + ROOT g = (f32[4], f32[4], f32[4]) tuple(c, d, f) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator sharding_isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed, + sharding_isolator.Run(module.get())); + EXPECT_TRUE(sharding_isolator_changed); + + HloDomainIsolator opname_isolator(OpNameDomainCreator); + TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, + opname_isolator.Run(module.get())); + EXPECT_TRUE(opname_isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + + HloDomainRemover sharding_remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, + sharding_remover.Run(module.get())); + EXPECT_TRUE(sharding_remover_changed); + + HloDomainRemover 
opname_remover(OpNameMetadata::KindName(), + OpNameDomainNormalizer); + TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, + opname_remover.Run(module.get())); + EXPECT_TRUE(opname_remover_changed); + + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "c")); +} + +TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + infeed = (f32[4], f32[4]) infeed(), + sharding={{maximal device=1}, {maximal device=0}} + gte0 = f32[4] get-tuple-element(infeed), index=0 + gte1 = f32[4] get-tuple-element(infeed), index=1 + copy0 = f32[4] copy(gte0) + copy1 = f32[4] copy(gte1) + ROOT add = f32[4] add(copy0, copy1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module.get(), "gte0", "infeed")); + EXPECT_TRUE(HasDomainEdge(module.get(), "gte1", "infeed")); + EXPECT_FALSE(HasDomainEdge(module.get(), "copy0", "gte0")); + EXPECT_FALSE(HasDomainEdge(module.get(), "copy1", "gte1")); + + // Inject unassigned tuple/gte within the infeed domain, to simulate the + // HLO passes adding unexpected instructions.
+ // + // infeed + // / \ + // GTE0 GTE1 + // / \ + // COPY0 COPY1 + // \ / + // \ / + // TUPLE + // | + // DOMAIN + HloInstruction* infeed = FindInstruction(module.get(), "infeed"); + ASSERT_NE(infeed, nullptr); + auto infeed_users = infeed->users(); + HloInstruction* new_gte0 = + infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(infeed->shape(), 0), infeed, 0)); + HloInstruction* new_copy0 = + infeed->parent()->AddInstruction(HloInstruction::CreateUnary( + new_gte0->shape(), HloOpcode::kCopy, new_gte0)); + HloInstruction* new_gte1 = + infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(infeed->shape(), 1), infeed, 1)); + HloInstruction* new_copy1 = + infeed->parent()->AddInstruction(HloInstruction::CreateUnary( + new_gte1->shape(), HloOpcode::kCopy, new_gte1)); + HloInstruction* new_tuple = infeed->parent()->AddInstruction( + HloInstruction::CreateTuple({new_copy0, new_copy1})); + for (HloInstruction* user : infeed_users) { + TF_EXPECT_OK(infeed->ReplaceUseWith(user, new_tuple)); + } + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + EXPECT_TRUE(remover_changed); + + struct Assignment { + HloInstruction* instruction; + int64 device; + } assignments[] = { + {new_gte0, 1}, + {new_copy0, 1}, + {new_gte1, 0}, + {new_copy1, 0}, + }; + for (auto& assignment : assignments) { + auto device = assignment.instruction->sharding_unique_device(); + EXPECT_TRUE(device.has_value()); + EXPECT_EQ(*device, assignment.device); + } + EXPECT_TRUE(new_tuple->has_sharding()); + EXPECT_EQ( + new_tuple->sharding(), + HloSharding::Tuple(new_tuple->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc 
b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index d236f83aeb9254b9c6e6d04629758ac2c8fd0da3..4ed1508d7067684a15d0fb7d86e69b055bc1333b 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -119,6 +119,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { return false; } + HloCloneContext context(module); bool changed = false; for (auto* computation : module->computations()) { for (auto* hlo : computation->MakeInstructionPostOrder()) { @@ -140,6 +141,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { // These are ops with embedded computations where it suffices to convert // the embedded computations instead of converting the ops themselves. if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall || + opcode == HloOpcode::kCrossReplicaSum || opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap || opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kSelectAndScatter || @@ -180,7 +182,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_); new_hlo = computation->AddInstruction( - hlo->CloneWithNewOperands(shape, new_operands, hlo->GetModule())); + hlo->CloneWithNewOperands(shape, new_operands, &context)); TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); new_hlo = ToElementType(new_hlo, eliminate_type_); @@ -189,16 +191,16 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_, replace_with_type_); - new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( - new_shape, new_operands, hlo->GetModule())); + new_hlo = computation->AddInstruction( + hlo->CloneWithNewOperands(new_shape, new_operands, &context)); TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); // Convert the elements of the result of `new_hlo` to produce a 
new // tuple with shape `old_shape`. new_hlo = ConvertTupleElements(new_hlo, old_shape); } else { - new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( - hlo->shape(), new_operands, hlo->GetModule())); + new_hlo = computation->AddInstruction( + hlo->CloneWithNewOperands(hlo->shape(), new_operands, &context)); TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index e7425c8ba790d8b18ba35bc3c3b9227b7a750e7e..e0648e14672c45e9a691fd6a674c9a2cd7605a12 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -52,12 +52,11 @@ namespace xla { namespace { using tensorflow::gtl::ArraySlice; -using tensorflow::gtl::FlatSet; template StatusOr> Compare(const Shape& shape, HloOpcode opcode, - const Literal& lhs_literal, - const Literal& rhs_literal) { + LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { std::function compare_op; switch (opcode) { case HloOpcode::kEq: @@ -95,7 +94,7 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, << HloOpcodeString(opcode); } - auto result = Literal::CreateFromShape(shape); + auto result = MakeUnique(shape); TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { return compare_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -106,8 +105,8 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, template <> StatusOr> Compare( - const Shape& shape, HloOpcode opcode, const Literal& lhs_literal, - const Literal& rhs_literal) { + const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { std::function compare_op; switch (opcode) { case HloOpcode::kEq: @@ -125,7 +124,7 @@ StatusOr> Compare( << HloOpcodeString(opcode); } - auto result = Literal::CreateFromShape(shape); + auto result = MakeUnique(shape); TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { return 
compare_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -310,6 +309,35 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( return result; } +StatusOr> HloEvaluator::EvaluateElementwiseBinaryOp( + HloOpcode opcode, const Literal& lhs, const Literal& rhs) { + std::unique_ptr lhs_instr = + HloInstruction::CreateConstant(lhs.CloneToUnique()); + std::unique_ptr rhs_instr = + HloInstruction::CreateConstant(rhs.CloneToUnique()); + + std::unique_ptr cloned_instruction = + HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(), + rhs_instr.get()); + auto result = Evaluate(cloned_instruction.get()); + + cloned_instruction->DetachFromOperands(); + return result; +} + +StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( + HloOpcode opcode, const Literal& operand) { + std::unique_ptr operand_instr = + HloInstruction::CreateConstant(operand.CloneToUnique()); + + std::unique_ptr cloned_instruction = + HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get()); + auto result = Evaluate(cloned_instruction.get()); + + cloned_instruction->DetachFromOperands(); + return result; +} + Status HloEvaluator::HandleParameter(HloInstruction* parameter) { CHECK_LT(parameter->parameter_number(), arg_literals_.size()); const Literal* input_literal = arg_literals_[parameter->parameter_number()]; @@ -860,6 +888,36 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { return Status::OK(); } +Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { + const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0)); + + TF_RET_CHECK(broadcast->dimensions().size() == + ShapeUtil::Rank(operand.shape())) + << "broadcast dimensions is of size: " << broadcast->dimensions().size() + << " and rank of operand_to_broadcast is: " + << ShapeUtil::Rank(operand.shape()); + // Checks that operand's dimensions are the same as the broadcast's + // dimensions along the dimensions to be broadcasted. 
+ for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { + TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == + operand.shape().dimensions(i)); + } + + TF_ASSIGN_OR_RETURN( + evaluated_[broadcast], + operand.Broadcast(broadcast->shape(), broadcast->dimensions())); + + return Status::OK(); +} + +Status HloEvaluator::HandleGenerateToken(HloInstruction* token) { + // Literals cannot represent a TOKEN shape so just create an empty tuple as + // the "result" of the kGenerateToken operation. + // TODO(b/109929053): Add support for TOKENs in Literals. + evaluated_[token] = Literal::MakeTuple({}); + return Status::OK(); +} + Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const auto result_shape = get_tuple_element->shape(); const int64 index = get_tuple_element->tuple_index(); @@ -915,9 +973,10 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { // Attach cloned computation to an empty HLO module so the existing ones are // not modified. 
HloModule empty_hlo_module("EmptyModuleForFusion", config); + HloCloneContext context(&empty_hlo_module); auto cloned_fused_computation = fusion->fused_instructions_computation()->Clone( - /*suffix=*/"clone_with_layout", &empty_hlo_module); + /*suffix=*/"clone_with_layout", &context); for (auto* instruction : cloned_fused_computation->instructions()) { LayoutUtil::SetToDefaultLayout(instruction->mutable_shape()); } @@ -952,8 +1011,8 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) { auto* true_computation = conditional->true_computation(); auto* false_computation = conditional->false_computation(); - auto result = Literal::CreateFromShape(conditional->shape()); HloEvaluator embedded_evaluator; + std::unique_ptr result; if (pred.Get({})) { result = embedded_evaluator .Evaluate(*true_computation, diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index cc5676ea7b05be6e0b7066bf703d8e48da0133ab..fc2fc9437b238a2e519401b2b121dfbef070e2dc 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -108,6 +109,12 @@ class HloEvaluator : public DfsHloVisitorWithDefault { const std::unordered_map& substitutions); + StatusOr> EvaluateElementwiseBinaryOp( + HloOpcode opcode, const Literal& lhs, const Literal& rhs); + + StatusOr> EvaluateElementwiseUnaryOp( + HloOpcode opcode, const Literal& operand); + protected: // Make HloEvaluatorTypedVisitor a friend because it is logically part of this // class. 
@@ -165,6 +172,32 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleSelect(HloInstruction* select) override; + Status HandleBroadcast(HloInstruction* broadcast) override; + + Status HandleGenerateToken(HloInstruction* token) override; + + // Returns the already-evaluated literal result for the instruction. + // A Constant instruction is considered evaluated and its literal will be + // returned directly without looking up the cache. + // Crash with log if the given instruction has not been evaluated previously. + const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) { + if (hlo->IsConstant()) { + return hlo->literal(); + } + auto it = evaluated_.find(hlo); + CHECK(it != evaluated_.end()) + << "could not find evaluated value for: " << hlo->ToString(); + return *(it->second); + } + + // Tracks the HLO instruction and its evaluated literal result. + // TODO(b/35950897): have better memory management here to free instructions + // that are no longer a parent for any other subsequent instruction in + // post-orderring. + // Must be cleared for each evaluation. + tensorflow::gtl::FlatMap> + evaluated_; + private: template static StatusOr> ElementWiseUnaryOpImpl( @@ -184,8 +217,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { ShapeUtil::HumanString(operand->shape()).c_str()); } - auto result = Literal::CreateFromShape(shape); - + auto result = MakeUnique(shape); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { return unary_op(operand_literal.Get(multi_index)); @@ -193,20 +225,6 @@ class HloEvaluator : public DfsHloVisitorWithDefault { return std::move(result); } - // Returns the already-evaluated literal result for the instruction. - // A Constant instruction is considered evaluated and its literal will be - // returned directly without looking up the cache. - // Crash with log if the given instruction has not been evaluated previously. 
- const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) { - if (hlo->IsConstant()) { - return hlo->literal(); - } - auto it = evaluated_.find(hlo); - CHECK(it != evaluated_.end()) - << "could not find evaluated value for: " << hlo->ToString(); - return *(it->second); - } - // Map from a primitive type to its associated (templated) DfsHloVisitor. // Note: the hash function here is only needed because current gcc std::hash // does not specialize for enum types. This should however be fixed in the @@ -215,14 +233,6 @@ class HloEvaluator : public DfsHloVisitorWithDefault { std::hash> typed_visitors_; - // Tracks the HLO instruction and its evaluated literal result. - // TODO(b/35950897): have better memory management here to free instructions - // that are no longer a parent for any other subsequent instruction in - // post-orderring. - // Must be cleared for each evaluation. - tensorflow::gtl::FlatMap> - evaluated_; - // Caches pointers to input literals, assuming they are in post-order. // Literals are not owned by this class, and they must outlive the lifetime of // each invocation to the Evaluate* method. 
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index cc16446778cbeac5ec4bed110adc9be8620084fe..72eb9930e92c340ab9f42cd563c27507623b2ba7 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -82,9 +82,9 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, auto element_type = expected->shape().element_type(); if (element_type == F32 || element_type == F64) { ErrorSpec error(aabs); - LiteralTestUtil::ExpectNear(*expected, *result, error); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error)); } else { - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } } @@ -100,7 +100,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, std::unique_ptr result = Evaluate(); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } bool use_bfloat16_; @@ -129,7 +129,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) { auto expected = Literal::CreateR2({{0, 4}, {2, 4}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { @@ -150,7 +150,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { auto expected = Literal::CreateR2({{0, 0}, {1, 1}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs select @@ -175,7 +175,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { auto expected = Literal::CreateR2({{2, 5}, {0, 4}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs @@ -262,13 +262,13 @@ 
TEST_P(HloEvaluatorTest, DoesCosR2) { auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = Literal::CreateR2({{1, -1}, {-1, 1}}); TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand), - use_bfloat16_ ? 0x1.0P-5 : 0x1.0P-20); + use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } TEST_P(HloEvaluatorTest, DoesSinR2) { auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = Literal::CreateR2({{0, 0}, {0, 0}}); TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand), - use_bfloat16_ ? 0x1.0P-5 : 0x1.0P-20); + use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } TEST_P(HloEvaluatorTest, DoesNotR2) { auto operand = @@ -307,7 +307,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { auto expected = Literal::CreateR2({{4, -16}, {-196, 12}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } // Verifies Reshape operation is correctly evaluated. @@ -315,7 +315,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { HloComputation::Builder b(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - LiteralTestUtil::CreateRandomLiteral( + Literal::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); auto literal_clone = literal->CloneToUnique(); HloInstruction* literal_instruction = @@ -333,7 +333,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { result->EachCell( [&](tensorflow::gtl::ArraySlice indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); - EXPECT_NEAR(value, literal_clone->Get(rindexes), 0x1.0P-5); + EXPECT_NEAR(value, literal_clone->Get(rindexes), 0.031250); }); } @@ -351,7 +351,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { std::unique_ptr result = Evaluate({}); - LiteralTestUtil::ExpectEqual(*result, *output_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); } TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { @@ -370,7 +370,7 @@ 
TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { std::unique_ptr result = Evaluate({}); - LiteralTestUtil::ExpectEqual(*result, *output_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); } TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { @@ -392,7 +392,7 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { auto expected = Literal::CreateR2({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { @@ -413,7 +413,7 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR1({100, 200}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { @@ -432,7 +432,7 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { std::unique_ptr result = Evaluate(); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { @@ -452,7 +452,7 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { std::unique_ptr result = Evaluate(); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } PaddingConfig CreatePaddingConfig( @@ -490,7 +490,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { auto expected = Literal::CreateR2( {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { @@ -525,7 +525,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { auto expected = Literal::CreateR4FromArray4D(*expected_array); - LiteralTestUtil::ExpectEqual(*expected, 
*result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, NegativePadding2D) { @@ -567,7 +567,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { (*expected_array)(0, 4) = 2.718f; auto expected = Literal::CreateR2FromArray2D(*expected_array); - LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(0x1.0P-5)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250))); } TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { @@ -606,7 +606,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { auto expected_array = MakeUnique>(0, 9); auto expected = Literal::CreateR2FromArray2D(*expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank1) { @@ -651,7 +651,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { // clang-format on auto expected = Literal::CreateR2FromArray2D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DotRank1AndRank2) { @@ -688,7 +688,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { auto expected = Literal::CreateR1({22.f, 28.f}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank2) { @@ -737,7 +737,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { }); auto expected = Literal::CreateR2FromArray2D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, SimpleConv1D) { @@ -785,7 +785,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { Array3D expected_array = {{{11.f, 18.f, 9.f}}}; auto expected = Literal::CreateR3FromArray3D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, 
Simple4x4Conv2DWith2x2Kernel) { @@ -847,7 +847,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { // clang-format on auto expected = Literal::CreateR4FromArray4D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { @@ -927,7 +927,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { auto expected = Literal::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { @@ -1004,7 +1004,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { auto expected = Literal::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { @@ -1067,7 +1067,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { })); auto expected = Literal::CreateR4FromArray4D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { @@ -1131,7 +1131,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { })); auto expected = Literal::CreateR4FromArray4D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, @@ -1203,7 +1203,7 @@ TEST_P(HloEvaluatorTest, })); auto expected = Literal::CreateR4FromArray4D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; @@ -1248,7 +1248,7 @@ 
void BM_ReducePrecisely(int num_iters) { HloComputation::Builder b("BM_ReducePrecisely"); HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - HloModule module("BM_ReducePrecisely", VersionedComputationHandle(), config); + HloModule module("BM_ReducePrecisely", config); constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 std::vector v(kNumElements, 1.0f); @@ -1319,7 +1319,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { auto expected = Literal::CreateR1({6, 18}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ReduceWindowMax) { @@ -1370,7 +1370,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{6, 7}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd) { @@ -1427,7 +1427,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{1, 3, 5}, {5, 11, 13}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { @@ -1490,7 +1490,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { std::vector output_dims = {4, 3, 3, 3, 4, 4}; std::unique_ptr result_literal = Literal::CreateFullWithDescendingLayout(output_dims, 8.0f); - LiteralTestUtil::ExpectEqual(*result_literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result)); } TEST_P(HloEvaluatorTest, StridedSlice) { @@ -1523,7 +1523,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { {19}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DynamicSlice) { @@ -1556,7 +1556,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { {6, 7, 8}, }); - 
LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } // Verifies that the HloEvaluator's implementation goes along with existing @@ -1591,7 +1591,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { {6, 7, 8}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { @@ -1627,7 +1627,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { {5, -6, -7}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, SetAndGetTuples) { @@ -1662,7 +1662,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { {5, 6, 7}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { @@ -1703,7 +1703,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { result_inner_literal.get(), }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Reverse) { @@ -1756,7 +1756,7 @@ TEST_P(HloEvaluatorTest, Reverse) { }); // clang-format on - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { @@ -1776,8 +1776,8 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { add, {{param0, Literal::CreateR1({1, 2, 3, 4}).get()}, {square, Literal::CreateR1({10, 20, 30, 40}).get()}}); TF_ASSERT_OK(result.status()); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({11, 22, 33, 44}), - *result.ValueOrDie()); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); } // Check that EvaluateWithSubstitutions works if one of the operands to the op @@ -1800,8 +1800,8 @@ TEST_P(HloEvaluatorTest, 
EvaluateWithSubstitutionsWithConstantOperand) { auto result = evaluator.EvaluateWithSubstitutions( add, {{square, Literal::CreateR1({10, 20, 30, 40}).get()}}); TF_ASSERT_OK(result.status()); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({11, 22, 33, 44}), - *result.ValueOrDie()); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { @@ -1823,9 +1823,9 @@ ENTRY main { std::unique_ptr operand = Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { @@ -1847,9 +1847,9 @@ ENTRY main { std::unique_ptr operand = Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 3}, {4, 6}, {7, 9}}), - *Evaluate({operand.get(), gather_indices.get()})); + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { @@ -1872,10 +1872,10 @@ ENTRY main { Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = Literal::CreateR2({{0, 2}, {2, 1}}); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR3( {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), - *Evaluate({operand.get(), gather_indices.get()})); + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { @@ -1900,9 +1900,9 @@ ENTRY main { {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr 
gather_indices = Literal::CreateR2({{0, 0}, {1, 0}}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR2({{-1, 1}, {-4, 4}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{-1, 1}, {-4, 4}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, @@ -1928,9 +1928,9 @@ ENTRY main { {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = Literal::CreateR2({{0, 0}, {1, 0}}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR2({{-2, 2}, {-1, 1}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{-2, 2}, {-1, 1}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { @@ -1952,9 +1952,9 @@ ENTRY main { std::unique_ptr operand = Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR2({{5}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{5}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { @@ -1977,9 +1977,9 @@ ENTRY main { Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = Literal::CreateR2({{2, 1}, {1, 1}}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR3({{{8}}, {{5}}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR3({{{8}}, {{5}}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { @@ -2000,9 +2000,34 @@ ENTRY main { ParseAndVerifyModule(hlo_text); std::unique_ptr operand = Literal::CreateR2({{}, {}, {}}); std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR2({{}, {}}), - 
*Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{}, {}}), + *Evaluate({operand.get(), gather_indices.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { + const string hlo_text = R"( +HloModule GatherXd + +ENTRY main { + operand = s32[3] parameter(0) + indices = s32[2,2,1] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1} +} +)"; + ParseAndVerifyModule(hlo_text); + + std::unique_ptr operand = Literal::CreateR1({0, 1, 2}); + std::unique_ptr gather_indices = + Literal::CreateR3({{{0}, {1}}, {{2}, {1}}}); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{0, 1}, {2, 1}}), + *Evaluate({operand.get(), gather_indices.get()}))); } // Verifies that HloEvaluator evaluates a HLO instruction that performs diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index f1cb36347850a5af8d9f0cb7b28d05bc7b382030..13f46407e33e36bdbef4c9032630101d6c18268f 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -161,36 +161,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleRound(round); } - Status HandleBroadcast(HloInstruction* broadcast) override { - parent_->evaluated_[broadcast] = - Literal::CreateFromShape(broadcast->shape()); - auto output = parent_->evaluated_[broadcast].get(); - const Literal& operand_to_broadcast = - parent_->GetEvaluatedLiteralFor(broadcast->operand(0)); - std::vector broadcast_indices( - ShapeUtil::Rank(broadcast->operand(0)->shape()), 0); - - TF_RET_CHECK(broadcast->dimensions().size() == - ShapeUtil::Rank(operand_to_broadcast.shape())) - << "broadcast dimensions is of size: " << broadcast->dimensions().size() - << 
" and rank of operand_to_broadcast is: " - << ShapeUtil::Rank(operand_to_broadcast.shape()); - // Checks that operand's dimensions are the same as the broadcast's - // dimensions along the dimensions to be broadcasted. - for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == - operand_to_broadcast.shape().dimensions(i)); - } - - return output->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - broadcast_indices[i] = multi_index[broadcast->dimensions(i)]; - } - return operand_to_broadcast.Get(broadcast_indices); - }); - } - template < typename NativeT, typename std::enable_if::value>::type* = nullptr> @@ -253,6 +223,29 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleExpm1(HloInstruction* expm1) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[expm1], + ElementWiseUnaryOp(expm1, [](ElementwiseT elem_operand) { + return std::expm1(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleExpm1(HloInstruction* floor) { + return InvalidArgument("Unsupported type for Expm1"); + } + + Status HandleExpm1(HloInstruction* floor) override { + return HandleExpm1(floor); + } + template < typename NativeT, typename std::enable_if::value>::type* = nullptr> @@ -284,6 +277,29 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleLog1p(HloInstruction* expm1) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[expm1], + ElementWiseUnaryOp(expm1, [](ElementwiseT elem_operand) { + return std::log1p(elem_operand); + })); + return Status::OK(); + } + + template < + typename 
NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleLog1p(HloInstruction* floor) { + return InvalidArgument("Unsupported type for Log1p"); + } + + Status HandleLog1p(HloInstruction* floor) override { + return HandleLog1p(floor); + } + template ::value && @@ -790,7 +806,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << ShapeUtil::HumanString(inferred_return_shape); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto result = Literal::CreateFromShape(result_shape); + auto result = MakeUnique(result_shape); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice out_index) { @@ -947,7 +963,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return static_cast(result_val); }; - auto result = Literal::CreateFromShape(result_shape); + auto result = MakeUnique(result_shape); TF_RETURN_IF_ERROR(result->PopulateParallel(func)); parent_->evaluated_[conv] = std::move(result); @@ -987,8 +1003,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - auto result = Literal::CreateFromShape(dot->shape()); - CHECK_EQ(dnums.lhs_batch_dimensions_size(), dnums.rhs_batch_dimensions_size()); @@ -1014,6 +1028,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { DimensionVector lhs_index(lhs_rank); DimensionVector rhs_index(rhs_rank); + auto result = MakeUnique(dot->shape()); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice result_index) { ElementwiseT result_val = static_cast(0); @@ -1107,7 +1122,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Create new HLO of padded shape with padding value. 
ReturnT scalar = parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); - auto result = Literal::CreateFromShape(pad->shape()); + auto result = MakeUnique(pad->shape()); TF_RETURN_IF_ERROR(result->Populate( [&scalar](tensorflow::gtl::ArraySlice multi_index) { return scalar; @@ -1272,7 +1287,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto operands = map->operands(); HloComputation* computation = map->to_apply(); - auto result = Literal::CreateFromShape(map->shape()); + auto result = MakeUnique(map->shape()); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); TF_RETURN_IF_ERROR(result->Populate( @@ -1388,8 +1403,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); auto init_scalar = init_literal.Get({}); - auto result = Literal::CreateFromShape(reduce->shape()); - const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions()); std::vector arg_dim_steps(arg_dimensions.size()); std::vector arg_dim_counts(arg_dimensions.size()); @@ -1408,6 +1421,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + auto result = MakeUnique(reduce->shape()); // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { @@ -1438,11 +1452,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Evaluate computation with specified literal operands. 
auto curr_val_literal = Literal::CreateR0(curr_val); auto result_val_literal = Literal::CreateR0(result_val); - std::vector args = {result_val_literal.get(), - curr_val_literal.get()}; std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) + embedded_evaluator + .Evaluate( + *function, + {result_val_literal.get(), curr_val_literal.get()}) .ConsumeValueOrDie(); // Clear visit states so that we can use the evaluator again on // the same computation. @@ -1486,7 +1501,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); auto init_scalar = init_literal.Get({}); - auto result = Literal::CreateFromShape(select_and_scatter->shape()); + auto result = MakeUnique(select_and_scatter->shape()); // Initialize result array with the init value. TF_RETURN_IF_ERROR(result->Populate( @@ -1510,9 +1525,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { int64 rank = ShapeUtil::Rank(operand_literal.shape()); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - DimensionVector source_index(rank); - - std::fill(source_index.begin(), source_index.end(), 0); + DimensionVector source_index(rank, 0); + + // Used in the dual IterateThroughWindow lambdas below. Hoisted to avoid + // dynamic memory allocations. + auto curr_val_literal = Literal::CreateR0(ReturnT()); + auto selected_val_literal = Literal::CreateR0(ReturnT()); + auto source_literal_scatter = Literal::CreateR0(ReturnT()); + auto scattered_literal = Literal::CreateR0(ReturnT()); do { // For each element in `source`, we place a window in `operand`. 
For each // window placement, we iterate inside the window twice: @@ -1536,14 +1556,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { selected_val = curr_val; selected_index = operand_index; } - const auto curr_val_literal = Literal::CreateR0(curr_val); - const auto selected_val_literal = - Literal::CreateR0(*selected_val); - - const std::vector args = { - selected_val_literal.get(), curr_val_literal.get()}; + curr_val_literal->Set({}, curr_val); + selected_val_literal->Set({}, *selected_val); std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*select, args) + embedded_evaluator + .Evaluate( + *select, + {selected_val_literal.get(), curr_val_literal.get()}) .ConsumeValueOrDie(); bool selected = !computed_result->Get({}); if (selected) { @@ -1560,14 +1579,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { selected_index->begin())) { auto source = source_literal.Get(source_index); auto scattered = result->Get(operand_index); - const auto source_literal = Literal::CreateR0(source); - const auto scattered_literal = - Literal::CreateR0(scattered); - - const std::vector args = { - source_literal.get(), scattered_literal.get()}; + source_literal_scatter->Set({}, source); + scattered_literal->Set({}, scattered); std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*scatter, args) + embedded_evaluator + .Evaluate(*scatter, + {source_literal_scatter.get(), + scattered_literal.get()}) .ConsumeValueOrDie(); result->Set(operand_index, computed_result->Get({})); // Clear visit states so that the we can use the evaluator again @@ -1607,8 +1625,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); auto init_scalar = init_literal.Get({}); - auto result = Literal::CreateFromShape(reduce_window->shape()); - // Creates a Shape object from window, for iteration below. 
std::vector window_dimension_sizes; for (const auto& window_dimension : window.dimensions()) { @@ -1621,6 +1637,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + auto result = MakeUnique(reduce_window->shape()); // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice output_index) { @@ -1639,10 +1656,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Literal::CreateR0(curr_val); const auto result_val_literal = Literal::CreateR0(result_val); - const std::vector args = { - result_val_literal.get(), curr_val_literal.get()}; std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) + embedded_evaluator + .Evaluate( + *function, + {result_val_literal.get(), curr_val_literal.get()}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again @@ -1689,14 +1707,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - // Enable CLZ only for int32 and uint32. + // Enable CLZ only for int32, uint32, int64 and uint64. 
template < typename NativeT, typename std::enable_if< (std::is_floating_point::value || std::is_integral::value || is_complex_t::value) && !(std::is_same::value || - std::is_same::value)>::type* = nullptr> + std::is_same::value || + std::is_same::value || + std::is_same::value)>::type* = nullptr> Status HandleClz(HloInstruction* clz) { return InvalidArgument("Unsupported type for Clz"); } @@ -1713,6 +1733,18 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template ::value || + std::is_same::value>::type* = nullptr> + Status HandleClz(HloInstruction* clz) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz], + ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) { + return 63 - tensorflow::Log2Floor64(elem_operand); + })); + return Status::OK(); + } + Status HandleClz(HloInstruction* clz) override { return HandleClz(clz); } @@ -1926,17 +1958,24 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector start(start_indices_typed.begin(), start_indices_typed.end()); - std::vector operand_indices(start.size()); + // Clamp the start indices so the slice is in-bounds w.r.t the operand. + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to officially document different behavior. + for (int64 i = 0; i < start.size(); ++i) { + start[i] = std::min( + std::max(int64{0}, start[i]), + operand_literal.shape().dimensions(i) - result_shape.dimensions(i)); + } - auto result = Literal::CreateFromShape(result_shape); + std::vector operand_indices(start.size()); + auto result = MakeUnique(result_shape); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { for (int64 i = 0; i < operand_indices.size(); ++i) { CHECK_GE(multi_index[i] + start[i], 0); - // Mod is only used here to be consistent with the existing - // backends' behavior. 
- operand_indices[i] = (multi_index[i] + start[i]) % - operand_literal.shape().dimensions(i); + operand_indices[i] = multi_index[i] + start[i]; } auto result = operand_literal.Get(operand_indices); @@ -1953,23 +1992,24 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto result = operand_literal.CloneToUnique(); auto start_indices_typed = start_indices_literal.data(); const auto rank = ShapeUtil::Rank(result->shape()); - std::vector start(rank, 0); + std::vector start(start_indices_typed.begin(), + start_indices_typed.end()); + // Clamp the update start indices so the slice is in-bounds w.r.t the + // operand. + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to officially document different behavior. for (int64 i = 0; i < rank; ++i) { - // All other implementations currently wrap-around the index, so this - // should do so as well. - start[i] = (start_indices_typed[i] % result->shape().dimensions(i)); - start[i] += (start[i] < 0) * result->shape().dimensions(i); + start[i] = std::min( + std::max(0, start[i]), + result->shape().dimensions(i) - update_literal.shape().dimensions(i)); } std::vector result_index(rank, 0); auto func = [&](tensorflow::gtl::ArraySlice update_index) { std::transform(update_index.begin(), update_index.end(), start.begin(), result_index.begin(), std::plus()); - // Same as above, wrap-around only to match other implementations' - // semantics.
- std::transform(result_index.begin(), result_index.end(), - result->shape().dimensions().begin(), result_index.begin(), - std::modulus()); result->Set(result_index, update_literal.Get(update_index)); return true; @@ -2020,7 +2060,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - auto result = Literal::CreateFromShape(shape); + auto result = MakeUnique(shape); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { @@ -2058,7 +2098,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - auto result = Literal::CreateFromShape(shape); + auto result = MakeUnique(shape); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index a0cb28246d3be541e798e85552436f64a3521f22..eba80c0f199f6224f4b46ac19af482c713585154 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -15,53 +15,33 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { -class HloExecutionProfileTest : public HloTestBase { - protected: - static constexpr int64 kInstructionCyclesIndex = 0; - static constexpr int64 kInstructionNameIndex = 19; -}; +using tensorflow::strings::StrCat; +using ::testing::AllOf; +using ::testing::ContainsRegex; -// Splits `lines` into a sequence of lines delimited by newlines and then split -// each of those lines into a sequence of words delimited by spaces. Filter out -// empty words. -std::vector> SplitIntoLinesAndWords( - tensorflow::StringPiece lines) { - std::vector> result; - for (const string& line : tensorflow::str_util::Split(lines, '\n')) { - std::vector words; - for (const string& word : tensorflow::str_util::Split(line, ' ')) { - if (!word.empty()) { - words.push_back(word); - } - } - result.push_back(std::move(words)); - } - - return result; -} +class HloExecutionProfileTest : public HloTestBase {}; TEST_F(HloExecutionProfileTest, Basic) { - std::unique_ptr hlo_module = CreateNewModule(); - - HloComputation::Builder builder(TestName()); + auto hlo_module = ParseHloString(R"( + HloModule test_module + ENTRY entry_computation { + lhs = f32[30,30]{1,0} parameter(0) + rhs = f32[30,30]{1,0} parameter(1) + add = f32[30,30]{1,0} add(lhs, rhs) + ROOT dot = f32[30,30]{1,0} dot(lhs, add), lhs_contracting_dims={1}, rhs_contracting_dims={0} + })") + .ValueOrDie(); + const HloInstruction* dot_instruction = + hlo_module->entry_computation()->root_instruction(); + const HloInstruction* add_instruction = dot_instruction->operand(1); Shape shape = ShapeUtil::MakeShape(F32, {30, 30}); - HloInstruction* param_lhs = - builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs")); - 
HloInstruction* param_rhs = - builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs")); - HloInstruction* add_instruction = - builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kAdd, param_lhs, param_rhs)); - HloInstruction* dot_instruction = - builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, param_lhs, add_instruction)); - - hlo_module->AddEntryComputation(builder.Build()); auto shape_size_function = [&](const Shape& shape) { const int64 pointer_size = 8; @@ -84,20 +64,12 @@ TEST_F(HloExecutionProfileTest, Basic) { execution_profile.SetCyclesTakenBy(add_instruction, add_cycles); execution_profile.SetCyclesTakenBy(dot_instruction, dot_cycles); - string rendered_profile = execution_profile.ToString( - backend().default_stream_executor()->GetDeviceDescription()); - std::vector> lines_and_words = - SplitIntoLinesAndWords(rendered_profile); - ASSERT_EQ(lines_and_words.size(), 8); - - const std::vector& line_2 = lines_and_words[2]; - const std::vector& line_3 = lines_and_words[3]; - - EXPECT_EQ(line_2[kInstructionCyclesIndex], std::to_string(dot_cycles)); - EXPECT_EQ(line_2[kInstructionNameIndex], '%' + dot_instruction->name()); - - EXPECT_EQ(line_3[kInstructionCyclesIndex], std::to_string(add_cycles)); - EXPECT_EQ(line_3[kInstructionNameIndex], '%' + add_instruction->name()); + EXPECT_THAT(execution_profile.ToString( + backend().default_stream_executor()->GetDeviceDescription()), + AllOf(ContainsRegex(StrCat(dot_cycles, R"(\b.*%)", + dot_instruction->name())), + ContainsRegex(StrCat(add_cycles, R"(\b.*%)", + add_instruction->name())))); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index b6b03876725e4d0db818e0bbc3738896f0c0e66e..28fc6c4209bcc14d890f28d6a9935c55b76586c0 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ 
-28,6 +28,8 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -321,13 +323,11 @@ optional MatchTrivialComputation(const HloComputation* computation) { class HloDotDumper { public: HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, - const DebugOptions& debug_options, bool show_metadata, - bool show_backend_config, const HloExecutionProfile* profile, - NodeFilter filter) + const DebugOptions& debug_options, bool show_backend_config, + const HloExecutionProfile* profile, NodeFilter filter) : computation_(computation), label_(std::string(label)), debug_options_(debug_options), - show_metadata_(show_metadata), show_backend_config_(show_backend_config), profile_(profile), filter_(std::move(filter)) {} @@ -395,7 +395,6 @@ class HloDotDumper { const HloComputation* computation_; // never null const string label_; // overall name for the graph const DebugOptions& debug_options_; - const bool show_metadata_; const bool show_backend_config_; const HloExecutionProfile* profile_; // may be null const NodeFilter filter_; @@ -430,7 +429,8 @@ class HloDotDumper { // When coloring by sharding information, we track the sharding string // representation to color association, by round-robin the color schemes. 
- std::unordered_map sharding_colors_; + std::unordered_map + sharding_colors_; int64 next_shard_color_ = 0; }; @@ -592,15 +592,26 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, const HloInstruction* parent_instr) { VLOG(2) << "Dumping subcomputation " << subcomp->name(); - const char* computation_fmt = R"(subgraph %s { -%s -label = <%s>; -labelloc = t; -tooltip = " "; -%s -} // %s + // Add an edge from the subcomputation to its parent node. If subcomp + // belongs to a fusion node, it's drawn in place of the fusion instruction, + // so there's no need to link those. + if (parent_instr->opcode() != HloOpcode::kFusion) { + const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction()); + VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() + << " as " << next_edge_id_; + edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); + const char* edge_fmt = + R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; + edges_.push_back(Printf( + edge_fmt, InstructionId(from), InstructionId(parent_instr), + SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); + } -)"; + // Have we already dumped this subcomputation? If so, generating the edge + // linking it and parent_instr is all we want to do in this function. + if (cluster_ids_.find(subcomp) != cluster_ids_.end()) { + return ""; + } cluster_ids_[subcomp] = next_cluster_id_++; @@ -647,25 +658,16 @@ tooltip = " "; string comp_body = DumpComputation(subcomp); - // Add an edge from the subcomputation to its parent node. If subcomp - // belongs to a fusion node, it's drawn in place of the fusion instruction, - // so there's no need to link those. 
- if (parent_instr->opcode() != HloOpcode::kFusion) { - const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction()); - VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() - << " as " << next_edge_id_; - edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); - const char* edge_fmt = - R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; - edges_.push_back(Printf( - edge_fmt, InstructionId(from), InstructionId(parent_instr), - SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); - } - - string computation = - Printf(computation_fmt, id, style, subcomp_label, comp_body, id); + const char* computation_fmt = R"(subgraph %s { +%s +label = <%s>; +labelloc = t; +tooltip = " "; +%s +} // %s - return computation; +)"; + return Printf(computation_fmt, id, style, subcomp_label, comp_body, id); } string HloDotDumper::DumpComputation(const HloComputation* comp) { @@ -723,11 +725,25 @@ string HloDotDumper::DumpRootTag() { to_id, node_body, node_shape, NodeColorAttributes(color)); } +static const HloConstantInstruction* TryGetFusionParameterConstant( + const HloInstruction* instr) { + if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) { + return nullptr; + } + const HloInstruction* fusion = instr->parent()->FusionInstruction(); + const HloInstruction* operand = fusion->operand(instr->parameter_number()); + return DynCast(operand); +} + bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { // If a node: // - // - is a tuple-shaped parameter, - // - is not a parameter to a fusion node, + // - is a parameter of a fusion node which is bound to a constant, + // + // or + // + // - is a tuple-shaped parameter, and + // - is not a parameter to a fusion node, and // - has at least kMinUsersToOmit users shown, and // - all of the shown users are get-tuple-elements, // @@ -735,6 +751,9 @@ bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { // // This helps us 
handle the common case where a while loop body has one big // tuple-shaped parameter. + if (TryGetFusionParameterConstant(instr) != nullptr) { + return true; + } const int kMinUsersToOmit = 3; return instr->opcode() == HloOpcode::kParameter && ShapeUtil::IsTuple(instr->shape()) && !instr->IsFused() && @@ -791,22 +810,22 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { } // Build the text that will be displayed inside the node. string node_body = node_label; - for (const string& s : {trivial_subcomputation, node_metadata, - node_backend_config, extra_info, inlined_constants}) { + for (const string& s : {trivial_subcomputation, node_backend_config, + extra_info, inlined_constants}) { if (!s.empty()) { StrAppend(&node_body, "
", s); } } - return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)" + return Printf(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)" "\n", - InstructionId(instr), node_body, node_shape, + InstructionId(instr), node_body, node_shape, node_metadata, NodeColorAttributes(color)); } string HloDotDumper::GetInstructionNodeInlinedOperands( const HloInstruction* instr) { - auto stringify_constant = [](const HloInstruction* constant) { + auto stringify_constant = [](const HloConstantInstruction* constant) { const auto& shape = constant->shape(); // If the shape has a dimension of size zero, print it as e.g. @@ -841,29 +860,26 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( ShapeUtil::HumanString(constant->shape())); }; - // Special case: If instr is a parameter to a fusion node, check whether the - // corresponding operand to the fusion node is a constant. - if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) { - const HloInstruction* fusion = instr->parent()->FusionInstruction(); - const HloInstruction* operand = fusion->operand(instr->parameter_number()); - if (operand->opcode() != HloOpcode::kConstant) { - return ""; - } - return StrCat("constant ", stringify_constant(operand)); - } - std::vector lines; for (int64 i = 0; i < instr->operand_count(); ++i) { const HloInstruction* operand = instr->operand(i); + const auto* constant_operand = DynCast(operand); optional operand_str; - if (operand->opcode() == HloOpcode::kConstant) { - operand_str = stringify_constant(operand); + if (constant_operand != nullptr) { + operand_str = stringify_constant(constant_operand); } else if (ShouldMergeIntoUsers(operand)) { - // Special case: If the operand is a parameter, use its parameter number - // rather than its name, because that's generally how people think of the - // node. + // Special case: If the operand is a parameter to a fusion node and it + // always has a constant value, display it like a regular constant. 
+ // + // For other parameters, use the parameter number rather than the proper + // name, because that's generally how people think of the node. if (operand->opcode() == HloOpcode::kParameter) { - operand_str = Printf("Parameter %lld", operand->parameter_number()); + if (const HloConstantInstruction* constant = + TryGetFusionParameterConstant(operand)) { + operand_str = stringify_constant(constant); + } else { + operand_str = Printf("Parameter %lld", operand->parameter_number()); + } } else { operand_str = operand->name(); } @@ -885,24 +901,26 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { if (!instr->has_sharding()) { return kDashedBorder; } - string shard_str = instr->sharding().ToString(); - auto it = sharding_colors_.find(shard_str); + auto it = sharding_colors_.find(instr->sharding()); if (it != sharding_colors_.end()) { return it->second; } ColorScheme color = static_cast( kBlue + (next_shard_color_++ % (kDashedBorder - kBlue))); - sharding_colors_.emplace(shard_str, color); + sharding_colors_.emplace(instr->sharding(), color); return color; } const auto kParameterColor = kOrange; // Special case: If this instruction has a parameter merged into it, paint it - // the same color as a parameter. + // the same color as a parameter. Unless the merged-in parameter is a + // parameter to a fusion node that is bound to a constant -- these aren't + // "real" parameters from the user's perspective. 
if (std::any_of(instr->operands().begin(), instr->operands().end(), [&](const HloInstruction* operand) { return operand->opcode() == HloOpcode::kParameter && - ShouldMergeIntoUsers(operand); + ShouldMergeIntoUsers(operand) && + TryGetFusionParameterConstant(operand) == nullptr; })) { return kParameterColor; } @@ -925,6 +943,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kDivide: case HloOpcode::kEq: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kGe: case HloOpcode::kGt: @@ -932,6 +951,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -963,6 +983,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kBitcast: case HloOpcode::kGetTupleElement: case HloOpcode::kTrace: + case HloOpcode::kGenerateToken: case HloOpcode::kTuple: return kWhite; case HloOpcode::kBroadcast: @@ -974,7 +995,6 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { } return kGreen; case HloOpcode::kConcatenate: - case HloOpcode::kCopy: case HloOpcode::kDynamicSlice: case HloOpcode::kGather: case HloOpcode::kPad: @@ -996,6 +1016,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kWhite; } return kGreen; + case HloOpcode::kCopy: + // Emphasize copy nodes, which are either physical transposes (and thus + // significant), or copies of read-only buffers (and thus dead weight). 
+ return kGreen; case HloOpcode::kConvolution: case HloOpcode::kDot: case HloOpcode::kFft: @@ -1011,6 +1035,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kReduceWindow: case HloOpcode::kSelectAndScatter: return kPurple; + case HloOpcode::kDomain: case HloOpcode::kFusion: case HloOpcode::kMap: return kGray; @@ -1066,10 +1091,6 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { } string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { - if (!show_metadata_) { - return ""; - } - std::vector lines; if (!instr->metadata().op_name().empty()) { lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name())); @@ -1089,11 +1110,11 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeBackendConfig( const HloInstruction* instr) { - if (!show_backend_config_ || instr->backend_config().empty()) { + if (!show_backend_config_ || instr->raw_backend_config_string().empty()) { return ""; } - return StrCat("backend_config=\"", instr->backend_config(), "\""); + return StrCat("backend_config=\"", instr->raw_backend_config_string(), "\""); } string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { @@ -1102,7 +1123,8 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { // Get the instruction's extra attributes excluding the names of its // subcomputations, since those are drawn explicitly in the graph. for (const auto& line : instr->ExtraAttributesToString( - HloPrintOptions().set_print_subcomputation_references(false))) { + HloPrintOptions().set_print_subcomputation_mode( + HloPrintOptions::PrintSubcomputationMode::kOff))) { lines.push_back(HtmlLikeStringSanitize(line)); } @@ -1151,6 +1173,20 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { return Join(lines, "
"); } +// Gets the total number of array elements in the given shape. For tuples, this +// is the sum of all the sizes of all of the array elements recursively in the +// tuple. +static int64 TotalElementsInShape(const Shape& shape) { + int64 elems = 0; + ShapeUtil::ForEachSubshape( + shape, [&](const Shape& subshape, const ShapeIndex& /*index*/) { + if (ShapeUtil::IsArray(subshape)) { + elems += ShapeUtil::ElementsIn(subshape); + } + }); + return elems; +} + void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { auto add_edge = [&](const HloInstruction* from, const HloInstruction* to, int64 operand_num, bool control_edge = false) { @@ -1170,9 +1206,16 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { } else if (control_edge) { edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\""; } - const char* kEdgeFmt = R"(%s -> %s [tooltip="%s -> %s" %s];)"; + + // We print "small" arrays using a hollow arrowhead and "large" arrays using + // a filled arrowhead. For now, we use an arbitrary cutoff for what "big" + // means. + bool is_big_array = TotalElementsInShape(from->shape()) >= 4096; + + const char* kEdgeFmt = R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)"; edges_.push_back(Printf(kEdgeFmt, InstructionId(from), InstructionId(to), - from->name(), to->name(), edge_label)); + (is_big_array ? "normal" : "empty"), from->name(), + to->name(), edge_label)); }; // Add edges from instr's operands to instr. 
Parameters within fusion @@ -1422,7 +1465,7 @@ string ExportGraph(const string& graph, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile, - bool show_metadata, bool show_backend_config) { + bool show_backend_config) { GraphRendererInterface::GraphKind graph_kind; string graph; if (debug_options.xla_hlo_dump_as_graphdef()) { @@ -1433,8 +1476,8 @@ string DumpGraph(const HloComputation& computation, const string& label, graph_kind = GraphRendererInterface::TF_GRAPHDEF; } else { graph = - HloDotDumper(&computation, label, debug_options, show_metadata, - show_backend_config, hlo_execution_profile, NodeFilter()) + HloDotDumper(&computation, label, debug_options, show_backend_config, + hlo_execution_profile, NodeFilter()) .Dump(); graph_kind = GraphRendererInterface::DOT_GRAPH; } @@ -1446,15 +1489,15 @@ string DumpGraph(const HloComputation& computation, const string& label, } string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata, bool show_backend_config) { + bool show_backend_config) { auto debug_options = node.GetModule()->config().debug_options(); string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); NodeFilter filter = MakeNodeFilter(&node, radius); - string graph = HloDotDumper(node.parent(), label, debug_options, - show_metadata, show_backend_config, - /*profile=*/nullptr, filter) - .Dump(); + string graph = + HloDotDumper(node.parent(), label, debug_options, show_backend_config, + /*profile=*/nullptr, filter) + .Dump(); return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); } diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index fc8e1468aca9c2edbc22c30a41a1be8b32a1feca..0b11f34abb7f0d937a24d11f4dc5d2d6a0aae6e7 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ 
b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -56,7 +56,7 @@ string MaybeDumpHloModule(const HloModule& module, const string& label, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile = nullptr, - bool show_metadata = false, bool show_backend_config = false); + bool show_backend_config = false); // Like DumpGraph, but renders only nodes "near" the given node in the graph. // @@ -64,7 +64,6 @@ string DumpGraph(const HloComputation& computation, const string& label, // (roughly) corresponds to the max distance a node may be from the primary node // before it's omitted from the graph. string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata = false, bool show_backend_config = false); // Dumps the HloModule::ToString() as a file into the provided directory path diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 8e52d926d85f1ce6fabeb2dedd2f8e0fe0c2051d..68f41a1cbb4db228f5dcf8b4a6130f05e81262a8 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -121,7 +121,7 @@ TEST(HloGraphDumperTest, Constant) { HloComputation::Builder b("b"); auto instruction = b.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(-42))); - instruction->set_name("i_am_a_constant_root_instruction"); + instruction->SetAndSanitizeName("i_am_a_constant_root_instruction"); HloModuleConfig config; HloModule m(TestName(), config); HloComputation* root_computation = m.AddEntryComputation(b.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 857cd39adb8d320ce1ebe9f718e82596b3757889..9e9bf6361d798012c1c783b4ad06582d83053ae6 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ 
b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -27,7 +27,9 @@ limitations under the License. #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -37,9 +39,11 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/human_readable_json.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -51,24 +55,169 @@ using ::tensorflow::strings::StrCat; /* static */ StatusOr> HloInstruction::CreateFromProto( - HloModule* module, const HloInstructionProto& proto, + const HloInstructionProto& proto, const tensorflow::gtl::FlatMap& instruction_map, const tensorflow::gtl::FlatMap& computation_map) { TF_RET_CHECK(!proto.opcode().empty()); TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); TF_RET_CHECK(proto.has_shape()); - auto instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); - for (const int64 operand_id : proto.operand_ids()) { - TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) - << "No instruction with id " << operand_id; - instruction->AppendOperand(instruction_map.at(operand_id)); - } - for (const int64 predecessor_id : proto.control_predecessor_ids()) { - TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) - << 
"No instruction with id " << predecessor_id; - TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) - ->AddControlDependencyTo(instruction.get())); + std::unique_ptr instruction; + const auto operands = [&instruction_map, &proto](int index) { + return instruction_map.at(proto.operand_ids(index)); + }; + const auto computations = [&computation_map, &proto](int index) { + return computation_map.at(proto.called_computation_ids(index)); + }; + switch (opcode) { + // Ops migrated to subclasses. + case HloOpcode::kBatchNormTraining: + CHECK_EQ(proto.operand_ids_size(), 3); + instruction = CreateBatchNormTraining( + proto.shape(), operands(0), operands(1), operands(2), proto.epsilon(), + proto.feature_index()); + break; + case HloOpcode::kBatchNormInference: + CHECK_EQ(proto.operand_ids_size(), 5); + instruction = CreateBatchNormInference( + proto.shape(), operands(0), operands(1), operands(2), operands(3), + operands(4), proto.epsilon(), proto.feature_index()); + break; + case HloOpcode::kBatchNormGrad: + CHECK_EQ(proto.operand_ids_size(), 5); + instruction = CreateBatchNormGrad(proto.shape(), operands(0), operands(1), + operands(2), operands(3), operands(4), + proto.epsilon(), proto.feature_index()); + break; + case HloOpcode::kFft: { + CHECK_EQ(proto.operand_ids_size(), 1); + std::vector fft_length(proto.fft_length().begin(), + proto.fft_length().end()); + instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(), + tensorflow::gtl::ArraySlice(fft_length)); + break; + } + case HloOpcode::kSend: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = CreateSend(operands(0), proto.channel_id()); + break; + case HloOpcode::kSendDone: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = CreateSendDone(operands(0)); + break; + case HloOpcode::kRecv: + CHECK_EQ(proto.operand_ids_size(), 0); + instruction = + CreateRecv(proto.shape().tuple_shapes(0), proto.channel_id()); + break; + case HloOpcode::kRecvDone: + CHECK_EQ(proto.operand_ids_size(), 1); + 
instruction = CreateRecvDone(operands(0)); + break; + case HloOpcode::kReverse: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = CreateReverse(proto.shape(), operands(0), + std::vector(proto.dimensions().begin(), + proto.dimensions().end())); + break; + case HloOpcode::kConcatenate: { + CHECK_EQ(proto.dimensions_size(), 1); + std::vector concat_operands(proto.operand_ids_size()); + std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), + concat_operands.begin(), + [&instruction_map](int64 operand_id) { + return instruction_map.at(operand_id); + }); + instruction = CreateConcatenate(proto.shape(), concat_operands, + proto.dimensions(0)); + break; + } + case HloOpcode::kReduce: + CHECK_EQ(proto.operand_ids_size(), 2); + CHECK_EQ(proto.called_computation_ids_size(), 1); + instruction = CreateReduce(proto.shape(), operands(0), operands(1), + std::vector(proto.dimensions().begin(), + proto.dimensions().end()), + computations(0)); + break; + case HloOpcode::kTranspose: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = + CreateTranspose(proto.shape(), operands(0), + std::vector(proto.dimensions().begin(), + proto.dimensions().end())); + break; + case HloOpcode::kBroadcast: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = + CreateBroadcast(proto.shape(), operands(0), + std::vector(proto.dimensions().begin(), + proto.dimensions().end())); + break; + case HloOpcode::kMap: { + CHECK_EQ(proto.called_computation_ids_size(), 1); + std::vector map_operands(proto.operand_ids_size()); + std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), + map_operands.begin(), + [&instruction_map](int64 operand_id) { + return instruction_map.at(operand_id); + }); + instruction = CreateMap(proto.shape(), map_operands, computations(0)); + break; + } + case HloOpcode::kSlice: { + CHECK_EQ(proto.operand_ids_size(), 1); + std::vector slice_starts, slice_limits, slice_strides; + for (const HloInstructionProto::SliceDimensions& slice_dimensions : 
+ proto.slice_dimensions()) { + slice_starts.push_back(slice_dimensions.start()); + slice_limits.push_back(slice_dimensions.limit()); + slice_strides.push_back(slice_dimensions.stride()); + } + instruction = CreateSlice(proto.shape(), operands(0), slice_starts, + slice_limits, slice_strides); + break; + } + case HloOpcode::kConstant: { + CHECK(proto.has_literal()); + TF_ASSIGN_OR_RETURN(auto literal, + Literal::CreateFromProto(proto.literal())); + instruction = CreateConstant(std::move(literal)); + break; + } + case HloOpcode::kTrace: { + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Trace instruction should have 1 operand but sees " + << proto.operand_ids_size(); + CHECK(proto.has_literal()); + TF_ASSIGN_OR_RETURN(auto literal, + Literal::CreateFromProto(proto.literal())); + instruction = CreateTrace(literal->GetR1U8AsString(), operands(0)); + break; + } + default: { + instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); + for (const int64 operand_id : proto.operand_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) + << "No instruction with id " << operand_id; + instruction->AppendOperand(instruction_map.at(operand_id)); + } + for (const int64 predecessor_id : proto.control_predecessor_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) + << "No instruction with id " << predecessor_id; + TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) + ->AddControlDependencyTo(instruction.get())); + } + if (instruction->opcode() != HloOpcode::kFusion) { + for (const int64 computation_id : proto.called_computation_ids()) { + TF_RET_CHECK(ContainsKey(computation_map, computation_id)) + << "No computation with id " << computation_id; + instruction->called_computations_.push_back( + computation_map.at(computation_id)); + } + } + break; + } } // In the proto, fused computations are held exclusively within the @@ -89,37 +238,16 @@ StatusOr> HloInstruction::CreateFromProto( << "No fusion computation with id " << fusion_id; 
fused_computation->SetFusionInstruction(instruction.get()); instruction->called_computations_.push_back(fused_computation); - } else { - for (const int64 computation_id : proto.called_computation_ids()) { - TF_RET_CHECK(ContainsKey(computation_map, computation_id)) - << "No computation with id " << computation_id; - instruction->called_computations_.push_back( - computation_map.at(computation_id)); - } - } - - if (instruction->opcode() == HloOpcode::kTrace) { - TF_RET_CHECK(instruction->operands().size() == 1) - << "Trace instruction should have 1 operand but sees " - << instruction->operands().size(); - instruction->mutable_operand(0)->set_tracing(instruction.get()); } TF_RET_CHECK(!proto.name().empty()); - instruction->name_ = proto.name(); + instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); - instruction->set_backend_config(proto.backend_config()); - if (proto.has_literal()) { - TF_ASSIGN_OR_RETURN(instruction->literal_, - Literal::CreateFromProto(proto.literal())); - } + instruction->backend_config_ = proto.backend_config(); instruction->parameter_number_ = proto.parameter_number(); instruction->tuple_index_ = proto.tuple_index(); - for (int64 dimension : proto.dimensions()) { - instruction->dimensions_.push_back(dimension); - } if (proto.has_window()) { instruction->window_ = MakeUnique(proto.window()); } @@ -132,12 +260,7 @@ StatusOr> HloInstruction::CreateFromProto( instruction->dot_dimension_numbers_ = MakeUnique(proto.dot_dimension_numbers()); } - for (const HloInstructionProto::SliceDimensions& slice_dimensions : - proto.slice_dimensions()) { - instruction->slice_starts_.push_back(slice_dimensions.start()); - instruction->slice_limits_.push_back(slice_dimensions.limit()); - instruction->slice_strides_.push_back(slice_dimensions.stride()); - } + instruction->exponent_bits_ = proto.exponent_bits(); instruction->mantissa_bits_ = proto.mantissa_bits(); for (int64 dynamic_slice_size : proto.dynamic_slice_sizes()) { @@ 
-149,16 +272,9 @@ StatusOr> HloInstruction::CreateFromProto( } instruction->outfeed_config_ = proto.outfeed_config(); instruction->distribution_ = proto.distribution(); - instruction->epsilon_ = proto.epsilon(); - instruction->feature_index_ = proto.feature_index(); - instruction->channel_id_ = proto.channel_id(); instruction->infeed_config_ = proto.infeed_config(); instruction->custom_call_target_ = proto.custom_call_target(); instruction->outfeed_shape_ = proto.outfeed_shape(); - instruction->fft_type_ = proto.fft_type(); - for (int64 fft_len : proto.fft_length()) { - instruction->fft_length_.push_back(fft_len); - } if (proto.has_sharding()) { TF_ASSIGN_OR_RETURN(const auto& sharding, @@ -185,26 +301,18 @@ StatusOr> HloInstruction::CreateFromProto( auto instruction = WrapUnique(new HloInstruction(HloOpcode::kParameter, shape)); instruction->parameter_number_ = parameter_number; - instruction->name_ = name; + instruction->SetAndSanitizeName(name); return instruction; } /* static */ std::unique_ptr HloInstruction::CreateTrace( const string& tag, HloInstruction* operand) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil())); - instruction->operands_.push_back(operand); - instruction->literal_ = Literal::CreateR1U8(tag); - operand->set_tracing(instruction.get()); - return instruction; + return MakeUnique(tag, operand); } /* static */ std::unique_ptr HloInstruction::CreateConstant( std::unique_ptr literal) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConstant, literal->shape())); - instruction->literal_ = std::move(literal); - return instruction; + return MakeUnique(std::move(literal)); } /* static */ std::unique_ptr @@ -256,11 +364,14 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kClz: + case HloOpcode::kDomain: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kImag: case 
HloOpcode::kIsFinite: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: @@ -339,13 +450,8 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, const Shape& shape, tensorflow::gtl::ArraySlice operands, HloComputation* map_computation, tensorflow::gtl::ArraySlice static_operands) { - CHECK(static_operands.empty()) << "static_operands not yet supported"; - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kMap, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->called_computations_.push_back(map_computation); - return instruction; + return MakeUnique(shape, operands, map_computation, + static_operands); } /* static */ std::unique_ptr HloInstruction::CreateConvolve( @@ -371,11 +477,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateFft( const Shape& shape, HloInstruction* operand, FftType fft_type, tensorflow::gtl::ArraySlice fft_length) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFft, shape)); - instruction->AppendOperand(operand); - instruction->fft_type_ = fft_type; - instruction->fft_length_.assign(fft_length.begin(), fft_length.end()); - return instruction; + return MakeUnique(shape, operand, fft_type, fft_length); } /* static */ std::unique_ptr HloInstruction::CreateDot( @@ -418,8 +520,20 @@ HloInstruction::CreateReducePrecision(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateCrossReplicaSum( - const Shape& shape, tensorflow::gtl::ArraySlice operands) { - return CreateNary(shape, HloOpcode::kCrossReplicaSum, operands); + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* reduce_computation, + tensorflow::gtl::ArraySlice replica_group_ids, + const tensorflow::gtl::optional& channel_id) { + // TODO(b/79737069): Remove the CHECK when supported. 
+ CHECK(replica_group_ids.empty()); + CHECK(!channel_id.has_value()); + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape)); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } + instruction->called_computations_.push_back(reduce_computation); + return instruction; } /* static */ std::unique_ptr HloInstruction::CreateInfeed( @@ -445,56 +559,44 @@ HloInstruction::CreateCrossReplicaSum( /* static */ std::unique_ptr HloInstruction::CreateSend( HloInstruction* operand, int64 channel_id) { - // Send instruction produces a tuple of {aliased operand, U32 context}. - Shape output_shape = ShapeUtil::MakeTupleShape( - {operand->shape(), ShapeUtil::MakeShape(U32, {})}); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kSend, output_shape)); - instruction->AppendOperand(operand); - instruction->channel_id_ = channel_id; - return instruction; + return MakeUnique(operand, channel_id); } /* static */ std::unique_ptr HloInstruction::CreateSendDone( HloInstruction* operand) { - CHECK(operand->opcode() == HloOpcode::kSend) + auto send_operand = DynCast(operand); + CHECK(send_operand != nullptr) << "SendDone must take the context operand from Send"; - auto instruction = WrapUnique( - new HloInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil())); - instruction->AppendOperand(operand); - instruction->channel_id_ = operand->channel_id(); - return instruction; + return MakeUnique(send_operand); } /* static */ std::unique_ptr HloInstruction::CreateRecv( const Shape& shape, int64 channel_id) { - // Recv instruction produces a tuple of {receive buffer, U32 context}. 
- Shape output_shape = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kRecv, output_shape)); - instruction->channel_id_ = channel_id; - return instruction; + return MakeUnique(shape, channel_id); } /* static */ std::unique_ptr HloInstruction::CreateRecvDone( HloInstruction* operand) { - CHECK(operand->opcode() == HloOpcode::kRecv) + auto recv_operand = DynCast(operand); + CHECK(recv_operand != nullptr) << "RecvDone must take the context operand from Recv"; - Shape output_shape = ShapeUtil::GetTupleElementShape(operand->shape(), 0); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kRecvDone, output_shape)); - instruction->AppendOperand(operand); - instruction->channel_id_ = operand->channel_id(); - return instruction; + return MakeUnique(recv_operand); } /* static */ std::unique_ptr HloInstruction::CreateReverse( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReverse, shape)); - instruction->AppendOperand(operand); - instruction->dimensions_.assign(dimensions.begin(), dimensions.end()); + return MakeUnique(shape, operand, dimensions); +} + +/* static */ std::unique_ptr +HloInstruction::CreateGenerateToken( + tensorflow::gtl::ArraySlice operands) { + auto instruction = WrapUnique(new HloInstruction( + HloOpcode::kGenerateToken, ShapeUtil::MakeTokenShape())); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } return instruction; } @@ -531,18 +633,8 @@ HloInstruction::CreateCrossReplicaSum( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSlice, shape)); - instruction->AppendOperand(operand); - instruction->slice_starts_.assign(start_indices.begin(), start_indices.end()); - 
instruction->slice_limits_.assign(limit_indices.begin(), limit_indices.end()); - instruction->slice_strides_.assign(strides.begin(), strides.end()); - // For backward compatibility with old serialized computations: if there are - // no strides, assume all strides are 1. - // TODO(b/63317920): remove this code. - if (instruction->slice_strides_.empty()) { - instruction->slice_strides_ = std::vector(start_indices.size(), 1LL); - } - return instruction; + return MakeUnique(shape, operand, start_indices, + limit_indices, strides); } /* static */ std::unique_ptr HloInstruction::CreateDynamicSlice( @@ -573,13 +665,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateConcatenate( const Shape& shape, tensorflow::gtl::ArraySlice operands, int64 dimension) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConcatenate, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->dimensions_.push_back(dimension); - return instruction; + return MakeUnique(shape, operands, dimension); } /* static */ std::unique_ptr HloInstruction::CreateConvert( @@ -602,13 +688,8 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, const Shape& shape, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReduce, shape)); - instruction->AppendOperand(arg); - instruction->AppendOperand(init_value); - instruction->dimensions_.assign(dimensions_to_reduce.begin(), - dimensions_to_reduce.end()); - instruction->called_computations_.push_back(reduce_computation); - return instruction; + return MakeUnique( + shape, arg, init_value, dimensions_to_reduce, reduce_computation); } /* static */ std::unique_ptr HloInstruction::CreateReduceWindow( @@ -629,14 +710,8 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape, HloInstruction* 
scale, HloInstruction* offset, float epsilon, int64 feature_index) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBatchNormTraining, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(scale); - instruction->AppendOperand(offset); - instruction->epsilon_ = epsilon; - instruction->feature_index_ = feature_index; - return instruction; + return MakeUnique( + shape, operand, scale, offset, epsilon, feature_index); } /* static */ std::unique_ptr @@ -644,16 +719,8 @@ HloInstruction::CreateBatchNormInference( const Shape& shape, HloInstruction* operand, HloInstruction* scale, HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, float epsilon, int64 feature_index) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBatchNormInference, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(scale); - instruction->AppendOperand(offset); - instruction->AppendOperand(mean); - instruction->AppendOperand(variance); - instruction->epsilon_ = epsilon; - instruction->feature_index_ = feature_index; - return instruction; + return MakeUnique( + shape, operand, scale, offset, mean, variance, epsilon, feature_index); } /* static */ std::unique_ptr @@ -662,16 +729,9 @@ HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand, HloInstruction* variance, HloInstruction* grad_output, float epsilon, int64 feature_index) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBatchNormGrad, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(scale); - instruction->AppendOperand(mean); - instruction->AppendOperand(variance); - instruction->AppendOperand(grad_output); - instruction->epsilon_ = epsilon; - instruction->feature_index_ = feature_index; - return instruction; + return MakeUnique(shape, operand, scale, mean, + variance, grad_output, epsilon, + feature_index); } /* static */ std::unique_ptr @@ -694,12 +754,8 @@ 
HloInstruction::CreateSelectAndScatter( /* static */ std::unique_ptr HloInstruction::CreateBroadcast( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice broadcast_dimensions) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBroadcast, shape)); - instruction->AppendOperand(operand); - instruction->dimensions_.assign(broadcast_dimensions.begin(), - broadcast_dimensions.end()); - return instruction; + return MakeUnique(shape, operand, + broadcast_dimensions); } /* static */ std::unique_ptr @@ -778,19 +834,7 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreateTranspose( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) { - CHECK_EQ(shape.dimensions().size(), dimensions.size()); - CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size()); - CHECK(std::equal(operand->shape().dimensions().begin(), - operand->shape().dimensions().end(), - Permute(dimensions, shape.dimensions()).begin())) - << "shape: " << ShapeUtil::HumanString(shape) - << ", operand->shape(): " << ShapeUtil::HumanString(shape) - << ", dimensions: {" << Join(dimensions, ", ") << "}"; - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kTranspose, shape)); - instruction->AppendOperand(operand); - instruction->dimensions_.assign(dimensions.begin(), dimensions.end()); - return instruction; + return MakeUnique(shape, operand, dimensions); } /* static */ std::unique_ptr HloInstruction::CreateFusion( @@ -819,6 +863,15 @@ HloInstruction::CreateBroadcastSequence( return instruction; } +void HloInstruction::set_single_sharding(const HloSharding& sharding) { + CHECK(!sharding.IsTuple()) << sharding; + if (ShapeUtil::IsTuple(shape())) { + set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape()))); + } else { + set_sharding(sharding); + } +} + void HloInstruction::SetupDerivedInstruction( HloInstruction* derived_instruction) const { if (sharding_ != nullptr) { @@ 
-1112,7 +1165,7 @@ RandomDistribution HloInstruction::random_distribution() const { return distribution_; } -bool HloInstruction::HasSideEffect() const { +bool HloInstruction::HasSideEffectNoRecurse() const { switch (opcode_) { case HloOpcode::kSend: case HloOpcode::kSendDone: @@ -1124,16 +1177,22 @@ bool HloInstruction::HasSideEffect() const { case HloOpcode::kTrace: case HloOpcode::kHostCompute: return true; - default: { - // Check if any of the called computations has a side effect. - for (const auto& computation : called_computations()) { - if (computation->HasSideEffect()) { - return true; - } - } + default: return false; + } +} + +bool HloInstruction::HasSideEffect() const { + if (HasSideEffectNoRecurse()) { + return true; + } + // Check if any of the called computations has a side effect. + for (const auto& computation : called_computations()) { + if (computation->HasSideEffect()) { + return true; } } + return false; } /* static */ std::unique_ptr HloInstruction::CreateCall( @@ -1217,25 +1276,53 @@ bool HloInstruction::HasSideEffect() const { return gather_dim_numbers; } +/* static */ std::unique_ptr HloInstruction::CreateDomain( + const Shape& shape, HloInstruction* operand, + std::unique_ptr operand_side_metadata, + std::unique_ptr user_side_metadata) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDomain, shape)); + instruction->operand_side_metadata_ = std::move(operand_side_metadata); + instruction->user_side_metadata_ = std::move(user_side_metadata); + instruction->AppendOperand(operand); + return instruction; +} + std::unique_ptr HloInstruction::CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, - HloModule* module, CloneMap* clone_map) const { + HloCloneContext* context) const { VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; for (const HloInstruction* new_operand : new_operands) { VLOG(3) << " %" << new_operand->name(); } - if (module == nullptr) { - module = 
GetModule(); - } std::unique_ptr clone; - // Explicitly call the factory for the instruction type. This is more robust // in the face of code changes than copying fields explicitly. This also // properly sets the user fields of the operands. switch (opcode_) { + // Ops migrated to subclasses. + // TODO(b/80131774): Remove this switch when migration is complete. + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormGrad: + case HloOpcode::kFft: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReverse: + case HloOpcode::kConcatenate: + case HloOpcode::kReduce: + case HloOpcode::kTranspose: + case HloOpcode::kBroadcast: + case HloOpcode::kMap: + case HloOpcode::kSlice: + case HloOpcode::kConstant: + case HloOpcode::kTrace: + clone = CloneWithNewOperandsImpl(shape, new_operands, context); + break; // Unary ops. case HloOpcode::kAbs: case HloOpcode::kRoundNearestAfz: @@ -1245,10 +1332,12 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kFloor: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: @@ -1292,23 +1381,24 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( new_operands[2]); break; // Other supported ops. 
- case HloOpcode::kBroadcast: - CHECK_EQ(new_operands.size(), 1); - clone = CreateBroadcast(shape, new_operands[0], dimensions_); - break; case HloOpcode::kCall: clone = CreateCall(shape, new_operands, to_apply()); break; case HloOpcode::kCustomCall: clone = CreateCustomCall(shape, new_operands, custom_call_target_); + if (window_ != nullptr) { + clone->window_ = MakeUnique(*window_); + } + if (convolution_dimension_numbers_ != nullptr) { + clone->convolution_dimension_numbers_ = + MakeUnique( + *convolution_dimension_numbers_); + } break; case HloOpcode::kHostCompute: clone = CreateHostCompute(shape, new_operands, channel_name_, cost_estimate_ns_); break; - case HloOpcode::kConcatenate: - clone = CreateConcatenate(shape, new_operands, dimensions(0)); - break; case HloOpcode::kConvert: CHECK_EQ(new_operands.size(), 1); clone = CreateConvert(shape, new_operands[0]); @@ -1332,30 +1422,18 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateDot(shape, new_operands[0], new_operands[1], *dot_dimension_numbers_); break; - case HloOpcode::kFft: - CHECK_EQ(new_operands.size(), 1); - clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_); - break; case HloOpcode::kCrossReplicaSum: - clone = CreateCrossReplicaSum(shape, new_operands); + clone = CreateCrossReplicaSum(shape, new_operands, to_apply()); break; case HloOpcode::kGetTupleElement: CHECK_EQ(new_operands.size(), 1); clone = CreateGetTupleElement(shape, new_operands[0], tuple_index()); break; - case HloOpcode::kMap: - clone = CreateMap(shape, new_operands, to_apply()); - break; case HloOpcode::kPad: CHECK_EQ(new_operands.size(), 2); clone = CreatePad(shape, new_operands[0], new_operands[1], *padding_config_); break; - case HloOpcode::kReduce: - CHECK_EQ(new_operands.size(), 2); - clone = CreateReduce(shape, new_operands[0], new_operands[1], dimensions_, - to_apply()); - break; case HloOpcode::kReduceWindow: CHECK_EQ(new_operands.size(), 2); clone = CreateReduceWindow(shape, new_operands[0], 
new_operands[1], @@ -1367,10 +1445,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CreateSelectAndScatter(shape, new_operands[0], select(), *window_, new_operands[1], new_operands[2], scatter()); break; - case HloOpcode::kReverse: - CHECK_EQ(new_operands.size(), 1); - clone = CreateReverse(shape, new_operands[0], dimensions_); - break; case HloOpcode::kRng: clone = CreateRng(shape, distribution_, new_operands); break; @@ -1378,11 +1452,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateReshape(shape, new_operands[0]); break; - case HloOpcode::kSlice: - CHECK_EQ(new_operands.size(), 1); - clone = CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_, - slice_strides_); - break; case HloOpcode::kDynamicSlice: clone = CreateDynamicSlice(shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); @@ -1392,10 +1461,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1], new_operands[2]); break; - case HloOpcode::kTranspose: - CHECK_EQ(new_operands.size(), 1); - clone = CreateTranspose(shape, new_operands[0], dimensions_); - break; case HloOpcode::kTuple: clone = CreateTuple(new_operands); *clone->mutable_shape() = shape; @@ -1405,13 +1470,17 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateWhile(shape, while_condition(), while_body(), new_operands[0]); break; - case HloOpcode::kConstant: - clone = CreateConstant(literal_->CloneToUnique()); - break; case HloOpcode::kFusion: { - CHECK_NE(module, nullptr); - auto new_fused_computation = module->AddEmbeddedComputation( - fused_instructions_computation()->Clone("clone", module, clone_map)); + HloModule* module = context != nullptr ? 
context->module() : GetModule(); + HloComputation* new_fused_computation = nullptr; + if (context != nullptr) { + new_fused_computation = + context->FindComputation(fused_instructions_computation()); + } + if (new_fused_computation == nullptr) { + new_fused_computation = module->AddEmbeddedComputation( + fused_instructions_computation()->Clone("clone", context)); + } clone = CreateFusion(/*shape=*/shape, /*fusion_kind=*/fusion_kind(), /*operands=*/new_operands, /*fusion_computation=*/new_fused_computation); @@ -1420,18 +1489,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kParameter: clone = CreateParameter(parameter_number_, shape, name_); break; - case HloOpcode::kBatchNormTraining: - CHECK_EQ(new_operands.size(), 3); - clone = - CreateBatchNormTraining(shape, new_operands[0], new_operands[1], - new_operands[2], epsilon(), feature_index()); - break; - case HloOpcode::kBatchNormInference: - CHECK_EQ(new_operands.size(), 5); - clone = CreateBatchNormInference( - shape, new_operands[0], new_operands[1], new_operands[2], - new_operands[3], new_operands[4], epsilon(), feature_index()); - break; case HloOpcode::kInfeed: CHECK_EQ(new_operands.size(), 0); clone = CreateInfeed(shape, infeed_config()); @@ -1440,49 +1497,37 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config()); break; - case HloOpcode::kBatchNormGrad: - CHECK_EQ(new_operands.size(), 5); - clone = CreateBatchNormGrad(shape, new_operands[0], new_operands[1], - new_operands[2], new_operands[3], - new_operands[4], epsilon(), feature_index()); - break; case HloOpcode::kConditional: CHECK_EQ(new_operands.size(), 3); clone = CreateConditional(shape, new_operands[0], new_operands[1], true_computation(), new_operands[2], false_computation()); break; - case HloOpcode::kSend: - CHECK_EQ(new_operands.size(), 1); - clone = CreateSend(new_operands[0], channel_id()); - break; - case 
HloOpcode::kSendDone: - CHECK_EQ(new_operands.size(), 1); - clone = CreateSendDone(new_operands[0]); - break; - case HloOpcode::kRecv: - CHECK_EQ(new_operands.size(), 0); - // The shape is a tuple, but CreateRecv() wants the raw data shape. - clone = - CreateRecv(ShapeUtil::GetTupleElementShape(shape, 0), channel_id()); - break; - case HloOpcode::kRecvDone: - CHECK_EQ(new_operands.size(), 1); - clone = CreateRecvDone(new_operands[0]); - break; case HloOpcode::kGather: CHECK_EQ(new_operands.size(), 2); clone = CreateGather(shape, new_operands[0], new_operands[1], *gather_dimension_numbers_, gather_window_bounds_); break; - case HloOpcode::kTrace: - LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); + case HloOpcode::kDomain: + CHECK_EQ(new_operands.size(), 1); + clone = + CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(), + user_side_metadata_->Clone()); + break; + case HloOpcode::kGenerateToken: + clone = CreateGenerateToken(new_operands); + break; } SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); - clone->set_backend_config(backend_config()); - if (clone_map != nullptr) { - InsertOrDie(clone_map, this, clone.get()); + clone->set_raw_backend_config_string(backend_config_); + if (context != nullptr) { + context->MapInstruction(this, clone.get()); + clone->ReplaceCalledComputations([&](HloComputation* callee) { + return callee->parent() != context->module() + ? 
context->module()->DeepCloneComputation(callee, context) + : callee; + }); } return clone; } @@ -1490,9 +1535,9 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( HloInstruction::~HloInstruction() {} std::unique_ptr HloInstruction::Clone( - const string& suffix, HloModule* module, CloneMap* clone_map) const { + const string& suffix, HloCloneContext* context) const { std::unique_ptr clone = - CloneWithNewOperands(shape_, operands_, module, clone_map); + CloneWithNewOperands(shape_, operands_, context); if (suffix.empty()) { clone->name_ = name(); } else { @@ -1552,33 +1597,6 @@ const HloInstruction* HloInstruction::LatestNonGteAncestor() const { return hlo; } -const Literal& HloInstruction::literal() const { - CHECK_EQ(HloOpcode::kConstant, opcode_); - return *literal_; -} - -bool HloInstruction::CanHaveDimensionsField() const { - return (opcode() == HloOpcode::kReverse || - opcode() == HloOpcode::kConcatenate || - opcode() == HloOpcode::kReduce || opcode() == HloOpcode::kBroadcast || - opcode() == HloOpcode::kTranspose); -} - -const std::vector& HloInstruction::dimensions() const { - CHECK(CanHaveDimensionsField()); - return dimensions_; -} - -int64 HloInstruction::dimensions(int64 index) const { - return dimensions()[index]; -} - -int64 HloInstruction::concatenate_dimension() const { - CHECK(opcode() == HloOpcode::kConcatenate); - CHECK_EQ(1, dimensions_.size()); - return dimensions(0); -} - int64 HloInstruction::tuple_index() const { CHECK_EQ(HloOpcode::kGetTupleElement, opcode_); return tuple_index_; @@ -1602,6 +1620,17 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const { LOG(FATAL) << "target was not an operand: " << target->ToString(); } +HloInstruction::InstructionVector HloInstruction::unique_operands() const { + InstructionVector unique; + tensorflow::gtl::FlatSet seen; + for (HloInstruction* operand : operands()) { + if (seen.insert(operand).second) { + unique.push_back(operand); + } + } + return unique; +} + Status 
HloInstruction::AddControlDependencyTo(HloInstruction* instruction) { TF_RET_CHECK(instruction->parent() == parent()); if (std::find(control_successors_.begin(), control_successors_.end(), @@ -1661,10 +1690,6 @@ void HloInstruction::AddUser(HloInstruction* user) { } } -bool HloInstruction::IsConstant() const { - return opcode_ == HloOpcode::kConstant; -} - bool HloInstruction::HasConstantOperand() const { for (const HloInstruction* operand : operands_) { if (operand->IsConstant()) { @@ -1677,26 +1702,29 @@ bool HloInstruction::HasConstantOperand() const { bool HloInstruction::IdenticalSlowPath( const HloInstruction& other, const std::function& - eq_computations, - const std::function& eq_shapes) const { + eq_computations) const { // Perform opcode specific checks. switch (opcode()) { // The result of these instructions only depend upon their opcode and // operands. case HloOpcode::kAbs: case HloOpcode::kAtan2: - case HloOpcode::kRoundNearestAfz: case HloOpcode::kAdd: + case HloOpcode::kBitcast: + case HloOpcode::kBitcastConvert: case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kClz: case HloOpcode::kComplex: + case HloOpcode::kConvert: case HloOpcode::kCopy: case HloOpcode::kCos: - case HloOpcode::kCrossReplicaSum: case HloOpcode::kDivide: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kEq: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kGe: case HloOpcode::kGt: @@ -1704,6 +1732,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kAnd: case HloOpcode::kNot: case HloOpcode::kOr: @@ -1716,6 +1745,8 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kPower: case HloOpcode::kReal: case HloOpcode::kRemainder: + case HloOpcode::kReshape: + case HloOpcode::kRoundNearestAfz: case HloOpcode::kSelect: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: @@ 
-1733,32 +1764,14 @@ bool HloInstruction::IdenticalSlowPath( other.fused_instructions_computation()); // These opcodes have complex or special behavior so just return false. + case HloOpcode::kDomain: case HloOpcode::kRng: - case HloOpcode::kTrace: case HloOpcode::kWhile: + case HloOpcode::kGenerateToken: return false; case HloOpcode::kParameter: - return parameter_number() == other.parameter_number() && - // Check the shape too because `this` and `other` may be in - // different HloComputations. - eq_shapes(shape(), other.shape()); - - case HloOpcode::kBatchNormTraining: - case HloOpcode::kBatchNormInference: - case HloOpcode::kBatchNormGrad: - return feature_index() == other.feature_index() && - epsilon() == other.epsilon(); - - // A constant is defined by the value in the literal. - case HloOpcode::kConstant: - return literal() == other.literal(); - - // A convert result is determined by the primitive type that the operand is - // converted into. - case HloOpcode::kConvert: - case HloOpcode::kBitcastConvert: - return shape().element_type() == other.shape().element_type(); + return parameter_number() == other.parameter_number(); // A reduce-precision operation is determined by the bit sizes. case HloOpcode::kReducePrecision: @@ -1781,16 +1794,6 @@ bool HloInstruction::IdenticalSlowPath( other.gather_dimension_numbers()) && gather_window_bounds() == other.gather_window_bounds(); - // FFT has various types & lengths. - case HloOpcode::kFft: - return fft_type() == other.fft_type() && - fft_length() == other.fft_length(); - - // Reduction results are determined by the reduction dimension and the - // reduction computation. 
- case HloOpcode::kReduce: - return dimensions() == other.dimensions() && - eq_computations(to_apply(), other.to_apply()); case HloOpcode::kReduceWindow: return eq_computations(to_apply(), other.to_apply()) && protobuf_util::ProtobufEquals(window(), other.window()); @@ -1802,43 +1805,30 @@ bool HloInstruction::IdenticalSlowPath( eq_computations(scatter(), other.scatter()) && protobuf_util::ProtobufEquals(window(), other.window()); - case HloOpcode::kReshape: - return eq_shapes(shape(), other.shape()); - - // Transpose result is determined by the final shape and the permutation. - case HloOpcode::kTranspose: - return eq_shapes(shape(), other.shape()) && - dimensions() == other.dimensions(); - // Remaining instructions with special values. - case HloOpcode::kBitcast: - return eq_shapes(shape(), other.shape()); - case HloOpcode::kBroadcast: - return eq_shapes(shape(), other.shape()) && - dimensions() == other.dimensions(); - case HloOpcode::kConcatenate: - return dimensions() == other.dimensions(); case HloOpcode::kGetTupleElement: return tuple_index() == other.tuple_index(); case HloOpcode::kPad: return protobuf_util::ProtobufEquals(padding_config(), other.padding_config()); - case HloOpcode::kSlice: - return slice_starts_ == other.slice_starts_ && - slice_limits_ == other.slice_limits_ && - slice_strides_ == other.slice_strides_; - case HloOpcode::kDynamicSlice: - return eq_shapes(shape(), other.shape()) && - dynamic_slice_sizes_ == other.dynamic_slice_sizes_; - case HloOpcode::kDynamicUpdateSlice: - return eq_shapes(shape(), other.shape()); case HloOpcode::kCall: - case HloOpcode::kMap: + case HloOpcode::kCrossReplicaSum: return eq_computations(to_apply(), other.to_apply()); case HloOpcode::kCustomCall: + if ((window_ == nullptr) != (other.window_ == nullptr) || + (window_ != nullptr && + !protobuf_util::ProtobufEquals(window(), other.window()))) { + return false; + } + if ((convolution_dimension_numbers_ == nullptr) != + (other.convolution_dimension_numbers_ == 
nullptr) || + (convolution_dimension_numbers_ != nullptr && + !protobuf_util::ProtobufEquals( + convolution_dimension_numbers(), + other.convolution_dimension_numbers()))) { + return false; + } return custom_call_target_ == other.custom_call_target_; - case HloOpcode::kReverse: - return dimensions() == other.dimensions(); case HloOpcode::kConditional: return eq_computations(true_computation(), other.true_computation()) && eq_computations(false_computation(), other.false_computation()); @@ -1847,21 +1837,31 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kSort: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: - case HloOpcode::kSend: - case HloOpcode::kSendDone: case HloOpcode::kHostCompute: return false; - } -} -bool HloInstruction::IsRank2Transpose() const { - return (opcode_ == HloOpcode::kTranspose) && - dimensions_ == std::vector({1, 0}) && - shape_.dimensions_size() == 2 && - std::equal(shape_.dimensions().begin(), shape_.dimensions().end(), - operands_[0]->shape_.dimensions().rbegin()); + // Ops migrated to subclasses should never come to this line. + // TODO(b/80131774): Remove this switch when migration is complete. 
+ case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormGrad: + case HloOpcode::kFft: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReverse: + case HloOpcode::kConcatenate: + case HloOpcode::kReduce: + case HloOpcode::kTranspose: + case HloOpcode::kBroadcast: + case HloOpcode::kMap: + case HloOpcode::kSlice: + case HloOpcode::kConstant: + case HloOpcode::kTrace: + LOG(FATAL) << "Base class impl called for opcode with subclass: " + << opcode(); + } } void HloInstruction::RemoveUser(HloInstruction* user) { @@ -1967,6 +1967,7 @@ HloComputation* HloInstruction::to_apply() const { case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: + case HloOpcode::kCrossReplicaSum: CHECK_EQ(called_computations_.size(), 1); return called_computations_[0]; default: @@ -2098,51 +2099,61 @@ string PrintName(const string& name, const HloPrintOptions& options) { } // namespace string HloInstruction::ToString(const HloPrintOptions& options) const { - string result = - StrCat(PrintName(name(), options), " = ", - ShapeUtil::HumanStringWithLayout(shape()), " ", - HloOpcodeString(opcode()), "(", OperandsToString(options), ")"); + CanonicalNameMap new_map; + return ToStringWithCanonicalNameMap(options, &new_map); +} + +string HloInstruction::ToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { + string result = ""; + + // Logic to print the instruction name (e.g. "%foo = "). + if (options.canonicalize_instruction_names()) { + if (options.is_in_nested_computation()) { + // If we are canonicalizing instruction names and this is a top-level + // HloInstruction::ToString() call, don't print an instruction name. 
+ StrAppend(&result, + PrintName(canonical_name_map->LookupOrInsert(name()), options), + " = "); + } + } else { + StrAppend(&result, PrintName(name(), options), " = "); + } + + // Print opcode, operand(s) and shape. + StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape()), " ", + HloOpcodeString(opcode()), "(", + OperandsToStringWithCanonicalNameMap(options, canonical_name_map), + ")"); + + // Print additional attributes. If an instruction contains a subcomputation, + // the subcomputation is also printed here. for (const string& extra : ExtraAttributesToString(options)) { StrAppend(&result, ", ", extra); } + if (options.print_metadata() && (!metadata_.op_type().empty() || !metadata_.op_name().empty() || !metadata_.source_file().empty())) { StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); } - if (options.print_backend_config() && !backend_config().empty()) { - StrAppend(&result, ", backend_config=\"", CEscape(backend_config()), "\""); + if (options.print_backend_config() && !backend_config_.empty()) { + StrAppend(&result, ", backend_config=\"", CEscape(backend_config_), "\""); } return result; } string HloInstruction::OperandsToString(const HloPrintOptions& options) const { + CanonicalNameMap new_map; + return OperandsToStringWithCanonicalNameMap(options, &new_map); +} + +string HloInstruction::OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { string operands; - if (opcode() == HloOpcode::kConstant) { - // For constants, show the actual value in place of an empty operand list. - if ((!ShapeUtil::IsTuple(shape()) && - ShapeUtil::ElementsIn(shape()) <= 10) || - options.print_large_constants()) { - // Literal::ToString emits multidimensional arrays over multiple - // lines. Compact this into one line by stripping out white space. 
- string tmp = literal().ToString(); - std::replace(tmp.begin(), tmp.end(), '\n', ' '); - std::vector v = tensorflow::str_util::Split(tmp, ' '); - bool first = true; - // Concatenate elements in "v" with spaces separating them, but ignoring - // empty entries. - for (const auto& s : v) { - if (s.empty()) { - continue; - } - StrAppend(&operands, (first ? "" : " "), s); - first = false; - } - } else { - // Do not show large constants or tuples. - operands = "{...}"; - } - } else if (opcode() == HloOpcode::kParameter) { + if (opcode() == HloOpcode::kParameter) { StrAppend(&operands, parameter_number_); } else { tensorflow::gtl::ArraySlice slice(operands_); @@ -2156,7 +2167,14 @@ string HloInstruction::OperandsToString(const HloPrintOptions& options) const { if (options.print_operand_shape()) { str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); } - if (!options.compact_operands()) { + + // In a top-level HloInstruction::ToString() call, the operand name is not + // part of the canonical string. 
+ if (options.canonicalize_instruction_names() && + options.is_in_nested_computation()) { + str.push_back(PrintName( + canonical_name_map->LookupOrInsert(operand->name()), options)); + } else if (!options.compact_operands()) { str.push_back(PrintName(operand->name(), options)); } StrAppend(out, Join(str, " ")); @@ -2171,13 +2189,11 @@ string HloInstruction::OperandsToString(const HloPrintOptions& options) const { std::vector HloInstruction::ExtraAttributesToString( const HloPrintOptions& options) const { - std::vector extra; + std::vector extra = ExtraAttributesToStringImpl(options); + if (opcode() == HloOpcode::kFusion) { extra.push_back(StrCat("kind=", xla::ToString(fusion_kind()))); } - if (CanHaveDimensionsField()) { - extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}")); - } if (window_ != nullptr && window_->dimensions_size() != 0) { extra.push_back(StrCat("window={", window_util::ToString(*window_), "}")); } @@ -2185,32 +2201,16 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("padding=", xla::PaddingConfigToString(*padding_config_))); } - if (opcode() == HloOpcode::kSlice) { - std::vector bounds; - bounds.reserve(slice_starts_.size()); - const bool omit_stride = - std::all_of(slice_strides_.begin(), slice_strides_.end(), - [](int64 stride) { return stride == 1; }); - for (int i = 0; i < slice_starts_.size(); ++i) { - string stride_str = omit_stride ? 
"" : StrCat(":", slice_strides_[i]); - bounds.push_back(StrCat("[", slice_starts_[i], ":", slice_limits_[i], - stride_str, "]")); - } - extra.push_back(StrCat("slice={", Join(bounds, ", "), "}")); - } + if (opcode() == HloOpcode::kDynamicSlice) { extra.push_back( StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")); } - if (opcode() == HloOpcode::kBatchNormTraining || - opcode() == HloOpcode::kBatchNormInference || - opcode() == HloOpcode::kBatchNormGrad) { - extra.push_back(StrCat("epsilon=", epsilon())); - extra.push_back(StrCat("feature_index=", feature_index())); - } if (convolution_dimension_numbers_ != nullptr) { - extra.push_back(ConvolutionDimensionNumbersToString()); + extra.push_back(StrCat( + "dim_labels=", + ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_))); } if (dot_dimension_numbers_ != nullptr) { extra.push_back(DotDimensionNumbersToString()); @@ -2220,12 +2220,9 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")); } - if (opcode() == HloOpcode::kFft) { - extra.push_back(StrCat("fft_type=", FftType_Name(fft_type()))); - extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}")); - } - if (options.print_subcomputation_references()) { + if (options.print_subcomputation_mode() == + HloPrintOptions::PrintSubcomputationMode::kNameOnly) { if (opcode() == HloOpcode::kWhile) { extra.push_back( StrCat("condition=", PrintName(while_condition()->name(), options))); @@ -2242,7 +2239,8 @@ std::vector HloInstruction::ExtraAttributesToString( PrintName(false_computation()->name(), options))); } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || opcode() == HloOpcode::kReduceWindow || - opcode() == HloOpcode::kReduce) { + opcode() == HloOpcode::kReduce || + opcode() == HloOpcode::kCrossReplicaSum) { extra.push_back( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if 
(!called_computations().empty()) { @@ -2253,11 +2251,44 @@ std::vector HloInstruction::ExtraAttributesToString( PrintName(computation->name(), options)); }))); } - } - - if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv || - opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) { - extra.push_back(StrCat("channel_id=", channel_id_)); + } else if (options.print_subcomputation_mode() == + HloPrintOptions::PrintSubcomputationMode::kFullBodies) { + HloPrintOptions new_options = options; + new_options.set_is_in_nested_computation(true); + switch (opcode()) { + case HloOpcode::kWhile: + extra.push_back( + StrCat("condition=\n", while_condition()->ToString(new_options))); + extra.push_back(StrCat("body=\n", while_body()->ToString(new_options))); + break; + case HloOpcode::kSelectAndScatter: + extra.push_back(StrCat("select=\n", select()->ToString(new_options))); + extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options))); + break; + case HloOpcode::kConditional: + extra.push_back(StrCat("true_computation=\n", + true_computation()->ToString(new_options))); + extra.push_back(StrCat("false_computation=\n", + false_computation()->ToString(new_options))); + break; + case HloOpcode::kCall: + case HloOpcode::kMap: + case HloOpcode::kReduceWindow: + case HloOpcode::kReduce: + extra.push_back( + StrCat("to_apply=\n", to_apply()->ToString(new_options))); + break; + default: + if (!called_computations().empty()) { + extra.push_back( + StrCat("calls=\n", + Join(called_computations(), ", ", + [&](string* out, const HloComputation* computation) { + StrAppend(out, computation->ToString(new_options)); + }))); + } + break; + } } if (opcode() == HloOpcode::kGetTupleElement) { @@ -2290,9 +2321,13 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back(StrCat("exponent_bits=", exponent_bits_)); extra.push_back(StrCat("mantissa_bits=", mantissa_bits_)); } - + if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { 
+ extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(), + "\", entry=", operand_side_metadata_->ToString(), + ", exit=", user_side_metadata_->ToString(), "}")); + } // By contract, we print the custom call target even if - // !options.print_subcomputation_references(), because the call target is not + // options.print_subcomputation_mode() == kOff, because the call target is not // an HloComputation. if (opcode() == HloOpcode::kCustomCall) { extra.push_back( @@ -2328,10 +2363,7 @@ HloInstructionProto HloInstruction::ToProto() const { } *proto.mutable_metadata() = metadata_; - proto.set_backend_config(backend_config()); - if (literal_ != nullptr) { - *proto.mutable_literal() = literal_->ToProto(); - } + proto.set_backend_config(backend_config_); proto.set_parameter_number(parameter_number_); if (opcode() == HloOpcode::kFusion) { proto.set_fusion_kind(xla::ToString(fusion_kind())); @@ -2344,9 +2376,6 @@ HloInstructionProto HloInstruction::ToProto() const { } proto.set_tuple_index(tuple_index_); - for (int64 dimension : dimensions_) { - proto.add_dimensions(dimension); - } if (window_ != nullptr) { *proto.mutable_window() = *window_; } @@ -2365,12 +2394,7 @@ HloInstructionProto HloInstruction::ToProto() const { proto.add_gather_window_bounds(bound); } } - for (int i = 0; i < slice_starts_.size(); ++i) { - auto* slice_dimension = proto.add_slice_dimensions(); - slice_dimension->set_start(slice_starts_[i]); - slice_dimension->set_limit(slice_limits_[i]); - slice_dimension->set_stride(slice_strides_[i]); - } + proto.set_exponent_bits(exponent_bits_); proto.set_mantissa_bits(mantissa_bits_); for (int64 slice_size : dynamic_slice_sizes_) { @@ -2383,15 +2407,12 @@ HloInstructionProto HloInstruction::ToProto() const { if (opcode() == HloOpcode::kRng) { proto.set_distribution(distribution_); } - proto.set_epsilon(epsilon_); - proto.set_feature_index(feature_index_); - proto.set_channel_id(channel_id_); proto.set_infeed_config(infeed_config_); 
proto.set_custom_call_target(custom_call_target_); *proto.mutable_outfeed_shape() = outfeed_shape_; - proto.set_fft_type(fft_type_); - for (int64 fft_len : fft_length_) { - proto.add_fft_length(fft_len); + + if (has_sharding()) { + *proto.mutable_sharding() = sharding().ToProto(); } proto.set_channel_name(channel_name_); @@ -2448,12 +2469,6 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) { trace_instruction_ = trace_instruction; } -string HloInstruction::TracingTag() const { - CHECK_EQ(HloOpcode::kTrace, opcode()); - CHECK(literal_ != nullptr); - return literal_->GetR1U8AsString(); -} - bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); } bool HloInstruction::IsFusable() const { @@ -2463,6 +2478,7 @@ bool HloInstruction::IsFusable() const { } // Some kinds of instructions don't make sense to fuse. switch (opcode_) { + case HloOpcode::kDomain: case HloOpcode::kParameter: return false; // Side effecting instrutions cannot be fused. @@ -2475,7 +2491,9 @@ HloComputation* HloInstruction::fused_instructions_computation() const { CHECK_EQ(opcode_, HloOpcode::kFusion); CHECK(!called_computations_.empty()); auto* fused_instructions_computation = called_computations_.front(); - CHECK(fused_instructions_computation->IsFusionComputation()); + CHECK(fused_instructions_computation->IsFusionComputation()) + << "Computation " << fused_instructions_computation->name() + << " is not a fusion kind"; return fused_instructions_computation; } @@ -2614,6 +2632,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleNegate(this); case HloOpcode::kExp: return visitor->HandleExp(this); + case HloOpcode::kExpm1: + return visitor->HandleExpm1(this); case HloOpcode::kFloor: return visitor->HandleFloor(this); case HloOpcode::kCeil: @@ -2622,6 +2642,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleClz(this); case HloOpcode::kLog: return visitor->HandleLog(this); + case 
HloOpcode::kLog1p: + return visitor->HandleLog1p(this); case HloOpcode::kTanh: return visitor->HandleTanh(this); case HloOpcode::kCos: @@ -2686,6 +2708,10 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleSendDone(this); case HloOpcode::kGather: return visitor->HandleGather(this); + case HloOpcode::kDomain: + return visitor->HandleDomain(this); + case HloOpcode::kGenerateToken: + return visitor->HandleGenerateToken(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -2954,10 +2980,6 @@ bool HloInstruction::IsElementwiseBinary() const { bool HloInstruction::IsElementwise() const { switch (opcode_) { - // Nullary elementwise operations. - case HloOpcode::kConstant: - return true; - // Unary elementwise operations. case HloOpcode::kAbs: case HloOpcode::kRoundNearestAfz: @@ -2968,10 +2990,12 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: @@ -3015,7 +3039,6 @@ bool HloInstruction::IsElementwise() const { // Other operations. 
case HloOpcode::kRng: - case HloOpcode::kMap: return true; case HloOpcode::kFusion: if (fusion_kind() != FusionKind::kLoop) { @@ -3036,7 +3059,7 @@ bool HloInstruction::IsElementwise() const { bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const { CHECK(IsElementwise()); - return !ShapeUtil::Equal(shape(), operand(operand_idx)->shape()); + return !ShapeUtil::SameDimensions(shape(), operand(operand_idx)->shape()); } namespace { @@ -3271,42 +3294,8 @@ string RandomDistributionToString(const RandomDistribution& distribution) { return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution)); } -StatusOr StringToRandomDistribution(const string& name) { - static std::unordered_map* map = [] { - static auto* map = new std::unordered_map; - for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) { - if (RandomDistribution_IsValid(i)) { - auto value = static_cast(i); - (*map)[RandomDistributionToString(value)] = value; - } - } - return map; - }(); - auto found = map->find(tensorflow::str_util::Lowercase(name)); - if (found == map->end()) { - return InvalidArgument("Unknown distribution"); - } - return found->second; -} - -std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { - return os << ToString(kind); -} - -string HloInstruction::ConvolutionDimensionNumbersToString() const { - string result; - if (convolution_dimension_numbers_ == nullptr) { - return result; - } - const ConvolutionDimensionNumbers& dnums = *convolution_dimension_numbers_; - // Show the given dimension labels in order of major to minor based on the - // shape's layout. - const auto append_dims = [&](const std::vector& dims, - const Shape& shape) { - CHECK_EQ(dims.size(), ShapeUtil::Rank(shape)); - StrAppend(&result, Join(dims, "")); - }; - +string ConvolutionDimensionNumbersToString( + const ConvolutionDimensionNumbers& dnums) { // lhs_dims[i] is the symbol of the logical dimension i for the lhs // operand. E.g. 
if batch has dimension number 2, then lhs_dims[2] == "b". std::vector lhs_dims(2 + dnums.input_spatial_dimensions().size()); @@ -3330,19 +3319,8 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i); } - result += "dim_labels="; - append_dims(lhs_dims, operand(0)->shape()); - result += "_"; - append_dims(rhs_dims, operand(1)->shape()); - result += "->"; - - // A convolution can be represented as a kConvolution HLO or as a CustomCall - // that returns a tuple, the first element of which is the result of the - // convolution. - Shape this_shape = - ShapeUtil::IsTuple(shape()) ? shape().tuple_shapes(0) : shape(); - append_dims(output_dims, this_shape); - return result; + return StrCat(Join(lhs_dims, ""), "_", Join(rhs_dims, ""), "->", + Join(output_dims, "")); } string HloInstruction::DotDimensionNumbersToString() const { @@ -3368,6 +3346,28 @@ string HloInstruction::DotDimensionNumbersToString() const { return Join(result, ", "); } +StatusOr StringToRandomDistribution(const string& name) { + static std::unordered_map* map = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) { + if (RandomDistribution_IsValid(i)) { + auto value = static_cast(i); + (*map)[RandomDistributionToString(value)] = value; + } + } + return map; + }(); + auto found = map->find(tensorflow::str_util::Lowercase(name)); + if (found == map->end()) { + return InvalidArgument("Unknown distribution"); + } + return found->second; +} + +std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { + return os << ToString(kind); +} + string HloInstruction::GatherDimensionNumbersToString() const { CHECK_NE(gather_dimension_numbers_.get(), nullptr); string output_window_dims = @@ -3399,6 +3399,31 @@ bool HloInstruction::CouldBeBitcast() const { } } +Status HloInstruction::GetBackendConfigInternal( + tensorflow::protobuf::Message* proto) const { + 
proto->Clear(); + + // Empty string does not parse as valid JSON, but it's a valid backend config, + // corresponding to the empty proto. + if (backend_config_.empty()) { + return Status::OK(); + } + return tensorflow::HumanReadableJsonToProto(backend_config_, proto); +} + +Status HloInstruction::set_backend_config( + const tensorflow::protobuf::Message& proto) { + TF_ASSIGN_OR_RETURN(backend_config_, BackendConfigToRawString(proto)); + return Status::OK(); +} + +/* static */ StatusOr HloInstruction::BackendConfigToRawString( + const tensorflow::protobuf::Message& proto) { + string ret; + TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(proto, &ret)); + return ret; +} + HloModule* HloInstruction::GetModule() const { if (parent_) { return parent_->parent(); @@ -3416,21 +3441,78 @@ void HloInstruction::set_outer_dimension_partitions( outer_dimension_partitions_ = outer_dimension_partitions; } -void HloInstruction::RelayoutConstant(const Layout& new_layout, - const ShapeIndex& shape_index) { - CHECK_EQ(opcode(), HloOpcode::kConstant); - Shape* mutable_array_subshape = - ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index); - CHECK(ShapeUtil::IsArray(*mutable_array_subshape)); +// TODO(b/80131774): Remove these temporary methods after transition. +int64 HloInstruction::feature_index() const { + return Cast(this)->feature_index(); +} - // Normally array_subshape will always have a layout, but this invariant is - // temporarily broken in LayoutAssignment::AssignLayouts. 
+float HloInstruction::epsilon() const { + return Cast(this)->epsilon(); +} - if (!mutable_array_subshape->has_layout() || - !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) { - literal_ = literal_->Relayout(new_layout, shape_index); - *mutable_array_subshape->mutable_layout() = new_layout; - } +FftType HloInstruction::fft_type() const { + return Cast(this)->fft_type(); +} + +const std::vector& HloInstruction::fft_length() const { + return Cast(this)->fft_length(); } +int64 HloInstruction::channel_id() const { + return Cast(this)->channel_id(); +} + +int64 HloInstruction::concatenate_dimension() const { + return Cast(this)->concatenate_dimension(); +} + +bool HloInstruction::IsRank2Transpose() const { + auto transpose = DynCast(this); + return transpose != nullptr && transpose->IsRank2Transpose(); +} + +int64 HloInstruction::slice_starts(int64 dimension) const { + return Cast(this)->slice_starts(dimension); +} + +const std::vector& HloInstruction::slice_starts() const { + return Cast(this)->slice_starts(); +} + +int64 HloInstruction::slice_limits(int64 dimension) const { + return Cast(this)->slice_limits(dimension); +} + +const std::vector& HloInstruction::slice_limits() const { + return Cast(this)->slice_limits(); +} + +int64 HloInstruction::slice_strides(int64 dimension) const { + return Cast(this)->slice_strides(dimension); +} + +const std::vector& HloInstruction::slice_strides() const { + return Cast(this)->slice_strides(); +} + +bool HloInstruction::IsInPlaceSlice() const { + return Cast(this)->IsInPlaceSlice(); +} + +const Literal& HloInstruction::literal() const { + return Cast(this)->literal(); +} + +bool HloInstruction::IsConstant() const { + return DynCast(this) != nullptr; +} + +void HloInstruction::RelayoutConstant(const Layout& new_layout, + const ShapeIndex& shape_index) { + Cast(this)->RelayoutConstant(new_layout, shape_index); +} + +string HloInstruction::TracingTag() const { + return Cast(this)->TracingTag(); +} } // namespace xla 
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 14be58d069e0d8520666766aedc6390bf3d57094..05662ef01bd6114026600c7e8af2607dd421b798 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -37,6 +37,8 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" @@ -50,6 +52,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -60,23 +63,31 @@ class HloModule; // A bunch of switches that control how the hlo text should be printed. class HloPrintOptions { public: + enum class PrintSubcomputationMode { + kOff, // Do not print anything about subcomputations. + kNameOnly, // Only print the name of subcomputations. + kFullBodies, // Print the full bodies of subcomputations. + }; + // Constructs the default print options: don't print large constants, don't // compact operands, no indentation. 
HloPrintOptions() : print_large_constants_(false), - print_subcomputation_references_(true), + print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly), print_metadata_(true), print_backend_config_(true), compact_operands_(false), print_operand_shape_(true), print_program_shape_(true), print_percent_(true), - indent_amount_(0) {} + canonicalize_instruction_names_(false), + indent_amount_(0), + is_in_nested_computation_(false) {} static HloPrintOptions ShortParsable() { return HloPrintOptions() .set_print_large_constants(true) - .set_print_subcomputation_references(true) + .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly) .set_print_metadata(false) .set_print_backend_config(false) .set_print_operand_shape(false) @@ -84,20 +95,28 @@ class HloPrintOptions { .set_print_percent(false); } + // Options to produce the canonical string representing an isomorphic + // computation graph. + static HloPrintOptions Canonical() { + return HloPrintOptions() + .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies) + .set_print_metadata(false) + .set_compact_operands(true) + .set_print_operand_shape(true) + .set_print_program_shape(false) + .set_print_percent(false) + .set_canonicalize_instruction_names(true); + } + // If true, large constants will be printed out. HloPrintOptions& set_print_large_constants(bool value) { print_large_constants_ = value; return *this; } - // If true, the names of subcomputations (e.g. a fusion node's fused - // computation) won't be printed. This makes the resulting text not parsable. - // - // A CustomCall's call target is printed even if - // print_subcomputation_references is false, because the call target isn't an - // HloComputation. 
- HloPrintOptions& set_print_subcomputation_references(bool value) { - print_subcomputation_references_ = value; + HloPrintOptions& set_print_subcomputation_mode( + PrintSubcomputationMode value) { + print_subcomputation_mode_ = value; return *this; } @@ -138,15 +157,29 @@ class HloPrintOptions { return *this; } + // If true, canonicalizes instructions' name. Instead of using "%foo.1" as + // the name of an instruction, we use "%tmp_1", "%tmp_2" etc. + HloPrintOptions& set_canonicalize_instruction_names(bool value) { + canonicalize_instruction_names_ = value; + return *this; + } + // The indent of the hlo text block. HloPrintOptions& set_indent_amount(int value) { indent_amount_ = value; return *this; } + // If true, indicates the instruction being printed is inside a nested + // computation. + HloPrintOptions& set_is_in_nested_computation(bool value) { + is_in_nested_computation_ = value; + return *this; + } + bool print_large_constants() const { return print_large_constants_; } - bool print_subcomputation_references() const { - return print_subcomputation_references_; + PrintSubcomputationMode print_subcomputation_mode() const { + return print_subcomputation_mode_; } bool print_metadata() const { return print_metadata_; } bool print_backend_config() const { return print_metadata_; } @@ -154,39 +187,145 @@ class HloPrintOptions { bool print_operand_shape() const { return print_operand_shape_; } bool print_program_shape() const { return print_program_shape_; } bool print_percent() const { return print_percent_; } + bool canonicalize_instruction_names() const { + return canonicalize_instruction_names_; + } int indent_amount() const { return indent_amount_; } + int is_in_nested_computation() const { return is_in_nested_computation_; } private: bool print_large_constants_; - bool print_subcomputation_references_; + PrintSubcomputationMode print_subcomputation_mode_; bool print_metadata_; bool print_backend_config_; bool compact_operands_; bool print_operand_shape_; 
bool print_program_shape_; bool print_percent_; + bool canonicalize_instruction_names_; int indent_amount_; + bool is_in_nested_computation_; }; -// HLO instructions are the IR used by the high-level compiler. +// For canonical string output, we need to have a canonical way to rename +// each instruction and its operands. Each operand is renamed as "tmp_<xx>", +// where <xx> is an index starting from 0. +class CanonicalNameMap { + public: + CanonicalNameMap() : index(0) {} + + string LookupOrInsert(const string& old_name) { + auto iter = canonical_name_map.find(old_name); + if (iter != canonical_name_map.end()) { + return iter->second; + } + + string new_name = tensorflow::strings::StrCat("tmp_", index++); + canonical_name_map[old_name] = new_name; + return new_name; + } + void Clear() { + canonical_name_map.clear(); + index = 0; + } + + private: + int64 index; + tensorflow::gtl::FlatMap<string, string> canonical_name_map; +}; + +// HLO instructions are the atomic unit of the high-level compiler's IR. +// +// HloInstructions live inside of an HloComputation, which is analogous to a +// function in other programming languages. Nodes have no total order within +// their computation. Instead, they have a partial ordering determined by their +// data and control dependencies. +// +// HLO does not have basic blocks or explicit "branch" instructions. Instead, +// certain HloInstructions -- namely, kWhile, kConditional, and kCall -- encode +// control flow. For example, the kConditional HLO executes one of two possible +// computations, depending on the runtime value of a predicate. +// +// HLO is pure (mostly). It has no concept of mutable state. Instead, data +// values are produced by one HLO and flow into consumers across dependency +// edges. class HloInstruction { public: + // A fusion node computes the same value a call to its fusion computation + // would compute. However, the choice of fusion kind dictates codegen + // strategy for the backend.
+ // + // To generate code for a kFusion HloInstruction, most backends do something + // like the following: + // + // 1) Identify the "primary" HloInstruction of the fused computation. + // 2) Emit code that does the work of the primary node, creating its inputs + // and transforming its outputs as specified by the fused computation. + // + // In step (2), the code emitted is usually similar to the code that would be + // emitted for an *unfused* version of the primary node, except that + // + // - when the primary node reads an element of one of its operands, instead + // of loading the value from memory, it *computes* the value based on the + // contents of the fused computation. + // - when the primary node outputs a value, instead of storing it to memory, + // it forwards the value to its users, which then perform additional + // computations before the value is finally stored to memory at the root of + // the fusion node. + // + // An HloInstruction's FusionKind helps us find the kFusion instruction's + // primary node, and can also affect how we generate code in step (2). + // + // - kInput: The primary node is the root of the fused instruction. + // + // - kOutput: The primary node is not the root of the fused instruction. + // This fusion kind requires that one operand buffer of the fusion + // instruction be able to alias the output buffer. This constraint is + // usually enough to let backends find the primary node unambiguously. + // + // - kLoop: The primary node is the root of the fused computation, but, + // unlike in input fusion, we prescribe a specific implementation for + // codegen. Rather than generating code that looks like the code we'd emit + // for an unfused version of the primary/root node, we emit code that + // generates one element of the root at a time. + // + // - kCustom: Custom category for backend-specific fusions that don't fit + // into the above patterns. 
+ // + // Not all backends support all fusion kinds, and given a particular fused + // computation, it's not in general safe to change its fusion kind. Creation + // of fusion nodes is always backend-specific. + // + // For elementwise ops (e.g. kAdd), most backends would emit a + // one-element-at-a-time implementation for the unfused version, so loop + // fusion and input fusion are probably equivalent if the root node is + // elementwise. They're not necessarily equivalent e.g. for kReduce, where an + // implementation might emit something more sophisticated for an unfused or + // input-fusion reduce, but will emit the naive code that reduces one element + // at a time for loop fusion with a reduce as the root. + // + // Another way to think of loop fusion is that it's equivalent to input + // fusion, but where the root node is an implicit identity node, whose + // unfused implementation is "read one element, write one element". + // + // TODO(b/79869434): This categorization scheme is not great. For one thing, + // input and loop fusion are basically the same thing: There is no reason for + // the HLO to encode backend-specific decisions about how e.g. a reduce that's + // the root of a fusion should be lowered. In addition, this scheme as + // written doesn't work for multi-output fusion, where the primary node is + // never actually the root (which is a kTuple instruction that gathers the + // multiple outputs of the fusion). enum class FusionKind { - kLoop, // Fused into a loop. - kInput, // Op's input is fused into the op itself. - kOutput, // Op's output is fused into the op itself. - // REQUIRES: At least one operand buffer must be able - // to alias the output buffer. - kCustom, // Custom category for backend-specific fusions that - // do not match any of the more specific ones. + kLoop, + kInput, + kOutput, + kCustom, }; - ~HloInstruction(); + virtual ~HloInstruction(); // Creates an instruction from the given proto. 
Arguments: // - // module: the module which will contain the instruction. The newly created - // instruction is *not* added to the module or any computation, however. // proto: the proto to convert from. // instruction_map: a map from instruction id to HloInstruction*. This map // must contain all operands of the newly constructed instruction. @@ -194,7 +333,7 @@ class HloInstruction { // must contain all computations which the newly constructed instruction // calls. static StatusOr> CreateFromProto( - HloModule* module, const HloInstructionProto& proto, + const HloInstructionProto& proto, const tensorflow::gtl::FlatMap& instruction_map, const tensorflow::gtl::FlatMap& computation_map); @@ -287,10 +426,26 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, const int exponent_bits, const int mantissa_bits); - // Creates a cross replica sum op. + // Creates a cross replica reduction op. + // + // `reduction_computation`: the reduction function. + // + // `replica_group_ids`: maps replica ids to subgroup ids. If empty, all + // replicas belong to one group. Allreduce will be applied within subgroups. + // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, + // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. + // + // `channel_id`: for Allreduce nodes from different models, if they have the + // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be + // applied cross models. + // + // TODO(b/79737069): Rename this to AllReduce. static std::unique_ptr CreateCrossReplicaSum( - const Shape& shape, - tensorflow::gtl::ArraySlice operands); + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* reduce_computation, + tensorflow::gtl::ArraySlice replica_group_ids = {}, + const tensorflow::gtl::optional& channel_id = + tensorflow::gtl::nullopt); // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. 
@@ -461,6 +616,13 @@ class HloInstruction { const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice window_bounds); + // Creates a kDomain instruction which delimits an HLO domain which have + // the provided user and operand side metadata. + static std::unique_ptr CreateDomain( + const Shape& shape, HloInstruction* operand, + std::unique_ptr operand_side_metadata, + std::unique_ptr user_side_metadata); + // Creates a fusion instruction. A fusion instruction contains one or more // fused instructions forming an expression with a single root // "fused_root". Additional instructions can be added to the fusion @@ -502,6 +664,11 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); + // Creates a token instruction used for joining or creating token types which + // thread through side-effecting operations. + static std::unique_ptr CreateGenerateToken( + tensorflow::gtl::ArraySlice operands); + // Creates an instance of GatherDimensionNumbers. static GatherDimensionNumbers MakeGatherDimNumbers( tensorflow::gtl::ArraySlice output_window_dims, @@ -512,6 +679,10 @@ class HloInstruction { // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } + // Returns true if this instruction has a side effect, irrespective of whether + // any called computations may contain an instruction with side effects. + bool HasSideEffectNoRecurse() const; + // Returns true if this instruction has a side effect. An instruction has a // side effect if it uses certain opcodes or calls a computation with a side // effect. @@ -536,6 +707,10 @@ class HloInstruction { using InstructionVector = tensorflow::gtl::InlinedVector; const InstructionVector& operands() const { return operands_; } + // Returns the vector of unique operands, in the same order they are found + // within the operand vector. 
+ InstructionVector unique_operands() const; + // Returns the index of 'target' in the operands sequence. // Precondition: target must be an operand (or a fatal error will occur). int64 operand_index(const HloInstruction* target) const; @@ -606,10 +781,8 @@ class HloInstruction { if (opcode() != other.opcode()) { return false; } - using EqShapeFuncType = bool (*)(const Shape&, const Shape&); - EqShapeFuncType eq_shapes = - layout_sensitive ? ShapeUtil::Equal : ShapeUtil::Compatible; - if (!eq_shapes(shape(), other.shape())) { + if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape()) + : ShapeUtil::Compatible(shape(), other.shape()))) { return false; } if (operands().size() != other.operands().size()) { @@ -624,15 +797,16 @@ class HloInstruction { } } - return IdenticalSlowPath(other, eq_computations, eq_shapes); + if (backend_config_ != other.backend_config_) { + return false; + } + + return IdenticalSlowPath(other, eq_computations); } // Returns whether the instruction has a constant operand. bool HasConstantOperand() const; - // Returns whether this instruction does a rank-2 transposition. - bool IsRank2Transpose() const; - // Replaces the use of this instruction in "user" with "new_producer". Note // that there might be multiple uses of this instruction in "user"; all will // be replaced. @@ -701,11 +875,6 @@ class HloInstruction { template Status Visit(DfsHloVisitorBase* visitor); - // Returns the literal associated with this instruction. - // - // Note: only constant and parameter opcodes have an associated literal. - const Literal& literal() const; - // Returns the parameter number associated with this instruction. // // Note: only parameter opcodes have an associated parameter number. @@ -714,17 +883,6 @@ class HloInstruction { return parameter_number_; } - // Returns the dimension sizes or numbers associated with this instruction. - // - // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape, - // and reverse. 
- const std::vector& dimensions() const; - int64 dimensions(int64 index) const; - - // Accessor for the dimension in which a concatenate HLO should occur. - // Precondition: opcode() == HloOpcode::kConcatenate - int64 concatenate_dimension() const; - // Returns the tuple index associated with this instruction. // // Precondition: opcode() == HloOpcode::kGetTupleElement @@ -824,7 +982,7 @@ class HloInstruction { string ToShortString() const; // Returns a serialized representation of this instruction. - HloInstructionProto ToProto() const; + virtual HloInstructionProto ToProto() const; // Returns a category for the HLO. This could be something like "convolution" // or "elementwise". @@ -836,47 +994,18 @@ class HloInstruction { HloInstruction* tracing() const; void set_tracing(HloInstruction* trace_instruction); - // Returns the channel id associated with the instruction. The id is - // shared between each Send/Recv pair and is globally unique to identify each - // channel. - // - // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv - int64 channel_id() const { return channel_id_; } - // Returns the channel name associated with the instruction. The name is // used to identify host Send/Recv operations. // // Precondition: opcode() == HloOpcode::kHostCompute string channel_name() const { return channel_name_; } - // Returns feature_index field associated with the instruction. The index - // represents the index of the feature dimension. - // - // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference, - // or kBatchNormGrad. - int64 feature_index() const { return feature_index_; } - - // Returns a epsilon value associated with the instruction. The is a small - // number added to the variance to avoid divide-by-zero error. - // - // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference, - // or kBatchNormGrad. - float epsilon() const { return epsilon_; } - // Returns the infeed configuration string. 
The infeed configuration includes // any metadata needed for the backend compiler (e.g., infeed buffer address) // and is target-dependent. string infeed_config() const { return infeed_config_; } void set_infeed_config(const string& config) { infeed_config_ = config; } - // Returns a tag to be used in tracing. - // - // Precondition: opcode() == HloOpcode::kTrace - string TracingTag() const; - - // Returns whether the instruction is a constant. - bool IsConstant() const; - // Returns true if this instruction is fused, ie contained within a fusion // instruction. bool IsFused() const; @@ -953,20 +1082,44 @@ class HloInstruction { } // Returns the sharding unique device, if any. tensorflow::gtl::optional sharding_unique_device() const { - if (sharding_ == nullptr || !sharding_->HasUniqueDevice()) { + if (sharding_ == nullptr) { return tensorflow::gtl::optional(); } - return sharding_->UniqueDevice().ValueOrDie(); + auto device = sharding_->UniqueDevice(); + return device.ok() ? device.ValueOrDie() + : tensorflow::gtl::optional(); } // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. void set_sharding(const HloSharding& sharding) { sharding_ = MakeUnique(sharding); } + void set_single_sharding(const HloSharding& sharding); + // Sets a sharding that assigns the current instruction to device. + void set_device_sharding(int64 device) { + set_single_sharding(HloSharding::AssignDevice(device)); + } // Remove any sharding from this operator. void clear_sharding() { sharding_ = nullptr; } // Return true if this operator has a sharding assigned. bool has_sharding() const { return sharding_ != nullptr; } + // Checks whether the instruction has compatible sharding with the other + // instruction. + bool has_compatible_sharding(const HloInstruction* other) const { + if (!has_sharding()) { + return !other->has_sharding(); + } + return other->has_sharding() ? 
sharding() == other->sharding() : false; + } + + // Retrieves the operand side metadata of a kDomain instruction. + const DomainMetadata& operand_side_metadata() const { + return *operand_side_metadata_; + } + // Retrieves the user side metadata of a kDomain instruction. + const DomainMetadata& user_side_metadata() const { + return *user_side_metadata_; + } // When creating a new instruction which either replaces, or shifts up (kCopy // insertion case), another instruction, we need to make sure the certain @@ -1020,48 +1173,6 @@ class HloInstruction { return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true); } - // Returns the start index in the given dimension for a slice node. - // - // Precondition: opcode() == HloOpcode::kSlice - int64 slice_starts(int64 dimension) const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_starts_[dimension]; - } - const std::vector& slice_starts() const { return slice_starts_; } - - // Returns the (exclusive) limit index in the given dimension for a slice - // node. - // - // Precondition: opcode() == HloOpcode::kSlice - int64 slice_limits(int64 dimension) const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_limits_[dimension]; - } - const std::vector& slice_limits() const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_limits_; - } - - // Returns the stride in the given dimension for a slice node. - // - // Precondition: opcode() == HloOpcode::kSlice - int64 slice_strides(int64 dimension) const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_strides_[dimension]; - } - const std::vector& slice_strides() const { return slice_strides_; } - - // Returns the flag that describes whether a slice must be lowered into an - // offset into the original operand. - bool IsInPlaceSlice() const { return is_in_place_slice_; } - - // Sets and returns the flag that describes whether a slice must be lowered - // into an offset into the original operand. 
- bool SetIsInPlaceSlice(bool value) { - is_in_place_slice_ = value; - return value; - } - // Returns the size of the slice in the given dimension for a dynamic // slice node. // @@ -1128,19 +1239,6 @@ class HloInstruction { MakeUnique(dnums); } - FftType fft_type() const { - CHECK_EQ(HloOpcode::kFft, opcode_); - return fft_type_; - } - - const std::vector& fft_length() const { - CHECK_EQ(HloOpcode::kFft, opcode_); - return fft_length_; - } - - // Returns the dump string of the convolution dimension numbers. - string ConvolutionDimensionNumbersToString() const; - // Returns data on the dimension numbers used for a dot operation. const DotDimensionNumbers& dot_dimension_numbers() const { CHECK(dot_dimension_numbers_ != nullptr); @@ -1168,30 +1266,19 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kRng RandomDistribution random_distribution() const; - // See documentation for Clone(). - using CloneMap = std::unordered_map; - // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of - // the instruction to form the name of the cloned instruction. Ignores the - // control predecessors and successors of this HLO instruction. - // - // If the module pointer is not nullptr, then any cloned computations will be - // added to this module in order to support deep cloning. Otherwise the module - // of the instruction is used. - // - // If clone_map is not nullptr, then each original instruction that is cloned - // will be inserted and map to its clone. clone_map should not already contain - // any of the instructions to clone. - std::unique_ptr Clone(const string& suffix = "clone", - HloModule* module = nullptr, - CloneMap* clone_map = nullptr) const; + // the instruction to form the name of the cloned instruction. 
+ // Ignores the control predecessors and successors of this HLO instruction. + std::unique_ptr<HloInstruction> Clone( + const string& suffix = "clone", HloCloneContext* context = nullptr) const; // Clones the HLO instruction as above but with new shape and operands. std::unique_ptr<HloInstruction> CloneWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - HloModule* module = nullptr, CloneMap* clone_map = nullptr) const; + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context = nullptr) const; // Returns the computations this instruction directly calls (if any). const std::vector<HloComputation*>& called_computations() const { @@ -1231,7 +1318,7 @@ class HloInstruction { bool IsElementwiseOnOperand(int64 operand_idx) const; // Returns true if this instruction is elementwise on all its operands. - bool IsElementwise() const; + virtual bool IsElementwise() const; // Returns true if this elementwise instruction implicitly broadcasts operand // `operand_idx`. @@ -1261,9 +1348,14 @@ class HloInstruction { std::tuple<bool, std::vector<int64>, std::vector<int64>> ReshapeMerelyInsertsOrDeletes1SizedDimensions() const; - // Gets/sets the string identifier for this instruction. + // Gets the string identifier for this instruction. const string& name() const { return name_; } - void set_name(tensorflow::StringPiece name) { name_ = std::string(name); } + + // Sets the string identifier for this instruction. Name will be sanitized to + // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". + void SetAndSanitizeName(const string& name) { + name_ = NameUniquer::GetSanitizedName(name); + } // Use the given NameUniquer to select a unique name for the instruction based // on the instruction's existing name. @@ -1286,13 +1378,34 @@ class HloInstruction { // this field and they cannot interpret it due to its meaning being backend // specific. // - // TODO(b/78194644): Introduce structured configuration format as per - // go/xla-heuristics.
- const string& backend_config() const { return backend_config_; } - void set_backend_config(string backend_config) { - backend_config_ = std::move(backend_config); + // ConfigProto should be a protobuf Message type. + template <typename ConfigProto> + StatusOr<ConfigProto> backend_config() const { + ConfigProto proto; + TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto)); + return std::move(proto); + } + Status set_backend_config(const tensorflow::protobuf::Message& proto); + + // Getter/setter for raw JSON-encoded backend config. Prefer the + // functions above that deal in proto Messages where possible. + const string& raw_backend_config_string() const { return backend_config_; } + void set_raw_backend_config_string(string config_str) { + backend_config_ = std::move(config_str); } + // Returns a string representation of a proto in the format used by + // raw_backend_config_string. + // + // This is morally equivalent to: + // + // HloInstruction instr; + // TF_RETURN_IF_ERROR(instr.set_backend_config(proto)); + // return instr.raw_backend_config_string(); + // + static StatusOr<string> BackendConfigToRawString( + const tensorflow::protobuf::Message& proto); + // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } const OpMetadata& metadata() const { return metadata_; } @@ -1323,47 +1436,135 @@ class HloInstruction { void set_outer_dimension_partitions( const std::vector<int64>& outer_dimension_partitions); - // Change the layout for an Constant Hlo instruction to match new_layout. For - // tuple shaped constants shape_index is the path to the internal array - // subshape whose layout needs to be changed. + // Old methods kept for smooth subclassing transition BEGIN. + // TODO(b/80131774): Remove this code. + + // Delegates to HloBatchNormInstruction::feature_index. + int64 feature_index() const; + + // Delegates to HloBatchNormInstruction::epsilon. + float epsilon() const; + + // Delegates to HloFftInstruction::fft_type.
+ FftType fft_type() const; + + // Delegates to HloFftInstruction::fft_length. + const std::vector& fft_length() const; + + // Delegates to HloSendRecvInstruction::channel_id. + int64 channel_id() const; + + // Returns the dimension sizes or numbers associated with this instruction. + virtual const std::vector& dimensions() const { + LOG(FATAL) << "Unimplemented method."; + } + virtual int64 dimensions(int64 index) const { + LOG(FATAL) << "Unimplemented method."; + } + + // Delegates to HloConcatenateInstruction::concatenate_dimension. + int64 concatenate_dimension() const; + + // Returns whether this instruction does a rank-2 transposition. + bool IsRank2Transpose() const; + + // Delegates to HloSliceInstruction::slice_start. + int64 slice_starts(int64 dimension) const; + const std::vector& slice_starts() const; + + // Delegates to HloSliceInstruction::slice_limits. + int64 slice_limits(int64 dimension) const; + const std::vector& slice_limits() const; + + // Delegates to HloSliceInstruction::slice_strides. + int64 slice_strides(int64 dimension) const; + const std::vector& slice_strides() const; + + // Delegates to HloSliceInstruction::IsInPlaceSlice. + bool IsInPlaceSlice() const; + + // Returns the literal associated with this instruction. + const Literal& literal() const; + + // Returns whether the instruction is a constant. + bool IsConstant() const; + + // Delegate to HloConstantInstruction::RelayoutConstant. void RelayoutConstant(const Layout& new_layout, const ShapeIndex& shape_index = {}); + // Delegates to HloTraceInstruction::TracingTag. + string TracingTag() const; + // Old methods kept for smooth subclassing transition END. + + protected: + // Internal constructor for a given opcode/shape, other fields must be filled + // by factory methods. + HloInstruction(HloOpcode opcode, const Shape& shape); + + // Appends operand to the list of operands and adds this instruction as a user + // of the operand. 
+ void AppendOperand(HloInstruction* operand); + + void AppendComputation(HloComputation* computation) { + called_computations_.push_back(computation); + } + private: + // Implementation for non-common logic of CloneWithNewOperands. + virtual std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + // TODO(b/80131774): This should be pure virtual. + LOG(FATAL) << "Unimplemented method."; + } + + // Implementation for non-common logic of ExtraAttributesToString. + virtual std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {}; + } + // Prints an instruction to a string. + // + // The canonical string representation needs to name operands and instruction + // names in a consistent way. This is implemented through the + // canonical_name_map. + string ToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const; + + // Prints an operand to a string. + virtual string OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const; + + // Allow HloInstruction to access the ToStringWithCanonicalNameMap() and + // OperandsToStringWithCanonicalNameMap() functions. + friend class HloComputation; + enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; // Helper class for computing OperandElementUse for kFusion. class FusionReusesParamElements; // See comments on Identical(). - // eq_shapes() is used to check shapes for equality, and would normally be - // expected to be ShapeUtil::Equals or ShapeUtil::Compatible, depending on - // whether we want a layout-sensitive check or not. - bool IdenticalSlowPath( + virtual bool IdenticalSlowPath( const HloInstruction& other, const std::function& - eq_computations, - const std::function& eq_shapes) const; + eq_computations) const; // Creates an n-ary elementwise operation. 
static std::unique_ptr CreateNary( const Shape& shape, HloOpcode opcode, tensorflow::gtl::ArraySlice operands); - // Appends operand to the list of operands and adds this instruction as a user - // of the operand. - void AppendOperand(HloInstruction* operand); - // Adds a user for this instruction. void AddUser(HloInstruction* user); // Removes a user for this instruction. void RemoveUser(HloInstruction* user); - // Internal constructor for a given opcode/shape, other fields must be filled - // by factory methods. - HloInstruction(HloOpcode opcode, const Shape& shape); - // Fuses the given instruction into this fusion instruction. When add_output // is false (which is the default), instruction_to_fuse is cloned and the // clone is placed in the fusion instruction. instruction_to_fuse is @@ -1390,15 +1591,15 @@ class HloInstruction { // Clones a fusion instruction with a new shape and operands. std::unique_ptr CloneFusionWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module = nullptr) const; - - // Returns true if this instruction can legally have the dimensions field - // set. Used for checking precondition of dimensions field accessors. - bool CanHaveDimensionsField() const; + HloCloneContext* context = nullptr) const; // Returns how this instruction uses elements of its `i`th operand. UseKind OperandElementUse(int64 i) const; + // Helper for implementing backend_config(). Parses backend_config_ into the + // given proto. + Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const; + int unique_id_; // Unique to this HloInstruction within a HloModule // Opcode for this instruction. @@ -1429,16 +1630,9 @@ class HloInstruction { // Result shape of this instruction. Shape shape_; - // Literal, only present for kConstant. - std::unique_ptr literal_; - // Constant index, only present for kGetTupleElement. 
int64 tuple_index_ = -1; - // Dimensions present for some operations that require reshaping or - // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse. - std::vector dimensions_; - // Describes the window in a windowed operation such as convolution. std::unique_ptr window_; @@ -1451,20 +1645,6 @@ class HloInstruction { std::unique_ptr gather_dimension_numbers_; std::vector gather_window_bounds_; - // Describes FFT type for an FFT instruction. - FftType fft_type_ = FftType::FFT; - - // Indicates the FFT length for an FFT instruction. - std::vector fft_length_; - - // Describes the [begin, end) index range for a slice. - std::vector slice_starts_; - std::vector slice_limits_; - std::vector slice_strides_; - - // Describes whether the slice can be lowered to an offset into the operand. - bool is_in_place_slice_ = false; - // The bit sizes for a reduce-precision operation. int32 exponent_bits_ = 0; int32 mantissa_bits_ = 0; @@ -1483,6 +1663,10 @@ class HloInstruction { // The sharding, if one exists. std::unique_ptr sharding_; + // Fields used by the kDomain instruction. + std::unique_ptr operand_side_metadata_; + std::unique_ptr user_side_metadata_; + // For parameter instructions this field holds the parameter number. int64 parameter_number_ = 0; @@ -1527,18 +1711,6 @@ class HloInstruction { // Only present for kRng. RandomDistribution distribution_; - // A small float number added to the variance to avoid divide-by-zero error. - // Only present for kBatchNormTraining. - float epsilon_ = 0.0f; - - // An integer value representing the index of the feature dimension. - // Only present for kBatchNormTraining. - int64 feature_index_ = -1; - - // Represents a unique identifier for each Send/Recv instruction pair. - // Only present for kSend or kRecv. - int64 channel_id_ = -1; - // The string representation of the infeed configuration. 
string infeed_config_; @@ -1568,6 +1740,9 @@ StatusOr StringToFusionKind( string PaddingConfigToString(const PaddingConfig& padding); string OpMetadataToString(const OpMetadata& metadata); string RandomDistributionToString(const RandomDistribution& distribution); +string ConvolutionDimensionNumbersToString( + const ConvolutionDimensionNumbers& dnums); + StatusOr StringToRandomDistribution(const string& name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); @@ -1576,13 +1751,20 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); // an HloInstruction* or a const HloInstruction*. // To make the iteration order over the map deterministic, the comparator // should not be using the pointer values, but rather an intrinsic property of -// the hlo. +// the hlo. Exception: null pointer values compare less than non-null. // // Note that this cannot be used for HLO instructions across multiple modules // since the id of HLO instructions are only unique within each HLO module. struct HloPtrComparator { bool operator()(const HloInstruction* const& lhs, const HloInstruction* const& rhs) const { + if (rhs == nullptr) { + // Nothing compares less than nullptr. + return false; + } + if (lhs == nullptr) { + return true; + } return lhs->unique_id() < rhs->unique_id(); } }; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 909cdc0b6269edaa09806fbe5c2f08197f7dc730..5d6f8b931f0c665fba03e1c845214fa83aabf12e 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -24,11 +24,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" namespace xla { namespace { @@ -340,7 +342,7 @@ TEST_F(HloInstructionTest, TrivialMap) { // Builds a parameter and feeds it to the map. HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10, "")); + HloInstruction::CreateParameter(0, f32a100x10, "p")); auto map = builder.AddInstruction( HloInstruction::CreateMap(f32a100x10, {param0}, add_f32)); module->AddEntryComputation(builder.Build()); @@ -379,7 +381,7 @@ TEST_F(HloInstructionTest, TrivialReduce) { // Builds a parameter and an initial value and feeds them to the reduce. 
HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10, "")); + HloInstruction::CreateParameter(0, f32a100x10, "p")); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); builder.AddInstruction( @@ -978,6 +980,23 @@ TEST_F(HloInstructionTest, FullyElementwise) { } } +TEST_F(HloInstructionTest, MapIsElementwise) { + auto module = CreateNewModule(); + const Shape r2f32 = ShapeUtil::MakeShapeWithLayout(F32, {10, 10}, {1, 0}); + HloComputation::Builder builder(TestName()); + HloComputation::Builder map_builder("id"); + map_builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + auto map_computation = module->AddEmbeddedComputation(map_builder.Build()); + auto x = + builder.AddInstruction(HloInstruction::CreateParameter(0, r2f32, "x")); + auto map = builder.AddInstruction( + HloInstruction::CreateMap(r2f32, {x}, map_computation)); + module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(map->IsElementwise()); +} + TEST_F(HloInstructionTest, PartiallyElementwise) { const Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); const Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 5}); @@ -1336,5 +1355,275 @@ TEST_F(HloInstructionTest, StringifyGather_1) { "index_vector_dim=2, window_bounds={30,29,28,27,26}"); } +TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { + // Tests stringification of a simple op, fusion, while, and conditional.
+ const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + + auto options = HloPrintOptions().Canonical(); + + EXPECT_EQ(dot->ToString(options), + "f32[5,20]{1,0} dot(f32[5,10]{1,0}, f32[10,20]{1,0}), " + "lhs_contracting_dims={1}, rhs_contracting_dims={0}"); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + HloInstruction* fusion = computation->CreateFusionInstruction( + {dot, reshape}, HloInstruction::FusionKind::kLoop); + + EXPECT_EQ( + fusion->ToString(options), + R"(f32[5,20]{1,0} fusion(f32[5,10]{1,0}, f32[20,10]{1,0}), kind=kLoop, calls= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"); +} + +TEST_F(HloInstructionTest, CanonnicalStringificationWhile) { + // Tests stringification of a simple op, fusion, while, and conditional. 
+ const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction({dot, reshape}, + HloInstruction::FusionKind::kLoop); + + HloInstruction* loop = builder.AddInstruction( + HloInstruction::CreateWhile(sout, computation, computation, x)); + + auto options = HloPrintOptions().Canonical(); + EXPECT_EQ(loop->ToString(options), + R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +}, body= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} 
parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +})"); +} + +TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { + // Tests stringification of a simple op, fusion, while, and conditional. + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction({dot, reshape}, + HloInstruction::FusionKind::kLoop); + + builder.AddInstruction( + HloInstruction::CreateWhile(sout, computation, computation, x)); + + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))); + HloInstruction* conditional = + builder.AddInstruction(HloInstruction::CreateConditional( + sout, pred, x, computation, x, computation)); + auto options = HloPrintOptions().Canonical(); + EXPECT_EQ( + conditional->ToString(options), + R"(f32[5,20]{1,0} conditional(pred[], f32[5,10]{1,0}, f32[5,10]{1,0}), true_computation= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 =
f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +}, false_computation= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +})"); +} + +TEST_F(HloInstructionTest, CheckDeepClone) { + const char* const hlo_string = R"( +HloModule Module + +addy (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT zadd = s32[] add(lhs, rhs) +} + +calla (x: s32[]) -> s32[] { + x = s32[] parameter(0) + reduce = s32[] reduce-window(x, x), to_apply=addy + ROOT xadd = s32[] add(x, reduce) +} + +body (bparam: s32[]) -> s32[] { + constant = s32[] constant(1) + bparam = s32[] parameter(0) + v = s32[] call(bparam), to_apply=calla + ROOT add = s32[] add(constant, bparam) +} + +condition (cparam: s32[]) -> pred[] { + xconstant = s32[] constant(5) + cparam = s32[] parameter(0) + ROOT greater-than = pred[] greater-than(xconstant, cparam) +} + +ENTRY entry (param: s32[]) -> s32[] { + eparam = s32[] parameter(0) + ROOT while = s32[] while(eparam), condition=condition, body=body + } +)"; + // Check that deep clones really deep clones every instruction and + // computations, without leaving dangling pointers to the old module. 
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(hlo_string)); + std::unique_ptr<HloModule> clone = module->Clone(); + for (HloComputation* computation : clone->computations()) { + EXPECT_EQ(computation->parent(), clone.get()); + for (HloInstruction* instruction : computation->instructions()) { + EXPECT_EQ(instruction->parent()->parent(), clone.get()); + } + } +} + +TEST_F(HloInstructionTest, IdenticalAccountsForBackendConfig) { + const Shape shape = ShapeUtil::MakeShape(F32, {42}); + HloComputation::Builder builder("test"); + HloInstruction* p = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p")); + + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p, p)); + HloInstruction* add2 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p, p)); + + EXPECT_TRUE(add1->Identical(*add2)); + add1->set_raw_backend_config_string("abc"); + EXPECT_FALSE(add1->Identical(*add2)); +} + +TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallWindow) { + auto instr1 = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + auto instr2 = instr1->Clone(); + EXPECT_TRUE(instr1->Identical(*instr2)); + + Window w = window_util::MakeWindow({1, 2, 3}); + instr1->set_window(w); + EXPECT_FALSE(instr1->Identical(*instr2)); +} + +TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallDnums) { + auto instr1 = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + auto instr2 = instr1->Clone(); + EXPECT_TRUE(instr1->Identical(*instr2)); + + ConvolutionDimensionNumbers dnums; + dnums.set_output_batch_dimension(42); + instr1->set_convolution_dimension_numbers(dnums); + EXPECT_FALSE(instr1->Identical(*instr2)); +} + +TEST_F(HloInstructionTest, CloneWindowOnCustomCall) { + auto instr = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
/*operands=*/{}, + /*custom_call_target=*/"foo"); + Window w = window_util::MakeWindow({1, 2, 3}); + instr->set_window(w); + auto clone = instr->Clone(); + EXPECT_TRUE(protobuf_util::ProtobufEquals(clone->window(), w)) + << clone->window().DebugString(); +} + +TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) { + auto instr = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + ConvolutionDimensionNumbers dnums; + dnums.set_output_batch_dimension(42); + instr->set_convolution_dimension_numbers(dnums); + auto clone = instr->Clone(); + EXPECT_TRUE(protobuf_util::ProtobufEquals( + clone->convolution_dimension_numbers(), dnums)) + << clone->convolution_dimension_numbers().DebugString(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc new file mode 100644 index 0000000000000000000000000000000000000000..1815bf1b1677f1ec3483c543a28e56dbf566c976 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -0,0 +1,691 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_instructions.h" + +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" + +namespace xla { + +using ::tensorflow::str_util::Join; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +HloBatchNormInstruction::HloBatchNormInstruction( + HloOpcode opcode, const Shape& shape, HloInstruction* operand, + HloInstruction* scale, float epsilon, int64 feature_index) + : HloInstruction(opcode, shape), + epsilon_(epsilon), + feature_index_(feature_index) { + AppendOperand(operand); + AppendOperand(scale); +} + +bool HloBatchNormInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const { + const auto& casted_other = static_cast<const HloBatchNormInstruction&>(other); + return feature_index() == casted_other.feature_index() && + epsilon() == casted_other.epsilon(); +} + +HloInstructionProto HloBatchNormInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_epsilon(epsilon_); + proto.set_feature_index(feature_index_); + return proto; +} + +std::vector<string> HloBatchNormInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("epsilon=", epsilon()), + StrCat("feature_index=", feature_index())}; +} + +HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, float epsilon, int64 feature_index) + : HloBatchNormInstruction(HloOpcode::kBatchNormTraining, shape, operand, + scale, epsilon, feature_index) { + AppendOperand(offset); +} + +std::unique_ptr<HloInstruction> +HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 3); + return MakeUnique<HloBatchNormTrainingInstruction>( + shape, new_operands[0], new_operands[1],
new_operands[2], epsilon(), + feature_index()); +} + +HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, + float epsilon, int64 feature_index) + : HloBatchNormInstruction(HloOpcode::kBatchNormInference, shape, operand, + scale, epsilon, feature_index) { + AppendOperand(offset); + AppendOperand(mean); + AppendOperand(variance); +} + +std::unique_ptr<HloInstruction> +HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 5); + return MakeUnique<HloBatchNormInferenceInstruction>( + shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], + new_operands[4], epsilon(), feature_index()); +} + +HloBatchNormGradInstruction::HloBatchNormGradInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* mean, HloInstruction* variance, HloInstruction* grad_output, + float epsilon, int64 feature_index) + : HloBatchNormInstruction(HloOpcode::kBatchNormGrad, shape, operand, scale, + epsilon, feature_index) { + AppendOperand(mean); + AppendOperand(variance); + AppendOperand(grad_output); +} + +std::unique_ptr<HloInstruction> +HloBatchNormGradInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 5); + return MakeUnique<HloBatchNormGradInstruction>( + shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], + new_operands[4], epsilon(), feature_index()); +} + +HloFftInstruction::HloFftInstruction( + const Shape& shape, HloInstruction* operand, FftType fft_type, + tensorflow::gtl::ArraySlice<int64> fft_length) + : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) { + fft_length_.assign(fft_length.begin(), fft_length.end()); + AppendOperand(operand); +} + +HloInstructionProto
HloFftInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_fft_type(fft_type_); + for (int64 fft_len : fft_length_) { + proto.add_fft_length(fft_len); + } + return proto; +} + +std::vector<string> HloFftInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("fft_type=", FftType_Name(fft_type())), + StrCat("fft_length={", Join(fft_length(), ","), "}")}; +} + +bool HloFftInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const { + const auto& casted_other = static_cast<const HloFftInstruction&>(other); + return fft_type() == casted_other.fft_type() && + fft_length() == casted_other.fft_length(); +} + +std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique<HloFftInstruction>(shape, new_operands[0], fft_type_, + fft_length_); +} + +HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, + const Shape& shape, + int64 channel_id) + : HloInstruction(opcode, shape), channel_id_(channel_id) {} + +HloInstructionProto HloSendRecvInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_channel_id(channel_id_); + return proto; +} + +std::vector<string> HloSendRecvInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("channel_id=", channel_id_)}; +} + +bool HloSendRecvInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const { + // Not yet supported. + return false; +} + +// Send instruction produces a tuple of {aliased operand, U32 context}.
+HloSendInstruction::HloSendInstruction(HloInstruction* operand, + int64 channel_id) + : HloSendRecvInstruction( + HloOpcode::kSend, + ShapeUtil::MakeTupleShape( + {CHECK_NOTNULL(operand)->shape(), ShapeUtil::MakeShape(U32, {})}), + channel_id) { + AppendOperand(operand); +} + +std::unique_ptr HloSendInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(new_operands[0], channel_id()); +} + +HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand) + : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil(), + CHECK_NOTNULL(operand)->channel_id()) { + AppendOperand(operand); +} + +std::unique_ptr +HloSendDoneInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique( + Cast(new_operands[0])); +} + +// Recv instruction produces a tuple of {receive buffer, U32 context}. 
+HloRecvInstruction::HloRecvInstruction(const Shape& shape, int64 channel_id) + : HloSendRecvInstruction( + HloOpcode::kRecv, + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}), + channel_id) {} + +std::unique_ptr HloRecvInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 0); + return MakeUnique( + ShapeUtil::GetTupleElementShape(shape, 0), channel_id()); +} + +HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand) + : HloSendRecvInstruction( + HloOpcode::kRecvDone, + ShapeUtil::GetTupleElementShape(operand->shape(), 0), + CHECK_NOTNULL(operand)->channel_id()) { + AppendOperand(operand); +} + +std::unique_ptr +HloRecvDoneInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique( + Cast(new_operands[0])); +} + +HloReverseInstruction::HloReverseInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions) + : HloInstruction(HloOpcode::kReverse, shape), + dimensions_(dimensions.begin(), dimensions.end()) { + AppendOperand(operand); +} + +HloInstructionProto HloReverseInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloReverseInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloReverseInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr HloReverseInstruction::CloneWithNewOperandsImpl( + const Shape& 
shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + dimensions()); +} + +HloConcatenateInstruction::HloConcatenateInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + int64 dimension) + : HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +HloInstructionProto HloConcatenateInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloConcatenateInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloConcatenateInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr +HloConcatenateInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(shape, new_operands, + dimensions(0)); +} + +HloReduceInstruction::HloReduceInstruction( + const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation) + : HloInstruction(HloOpcode::kReduce, shape), + dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) { + AppendOperand(arg); + AppendOperand(init_value); + AppendComputation(reduce_computation); +} + +HloInstructionProto HloReduceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + 
+std::vector HloReduceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloReduceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + // Reduction results are determined by the reduction dimension and the + // reduction computation. + return dimensions() == casted_other.dimensions() && + eq_computations(to_apply(), casted_other.to_apply()); +} + +std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique( + shape, new_operands[0], new_operands[1], dimensions(), to_apply()); +} + +HloTransposeInstruction::HloTransposeInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions) + : HloInstruction(HloOpcode::kTranspose, shape), + dimensions_(dimensions.begin(), dimensions.end()) { + CHECK_EQ(shape.dimensions().size(), dimensions.size()); + CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size()); + CHECK(std::equal(operand->shape().dimensions().begin(), + operand->shape().dimensions().end(), + Permute(dimensions, shape.dimensions()).begin())) + << "shape: " << ShapeUtil::HumanString(shape) + << ", operand->shape(): " << ShapeUtil::HumanString(operand->shape()) + << ", dimensions: {" << Join(dimensions, ", ") << "}"; + AppendOperand(operand); +} + +bool HloTransposeInstruction::IsRank2Transpose() const { + return dimensions() == std::vector({1, 0}) && + shape().dimensions_size() == 2 && + std::equal(shape().dimensions().begin(), shape().dimensions().end(), + operand(0)->shape().dimensions().rbegin()); +} + +HloInstructionProto HloTransposeInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64
dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloTransposeInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloTransposeInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr +HloTransposeInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + dimensions()); +} + +HloBroadcastInstruction::HloBroadcastInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice broadcast_dimension) + : HloInstruction(HloOpcode::kBroadcast, shape), + dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) { + AppendOperand(operand); +} + +HloInstructionProto HloBroadcastInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloBroadcastInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloBroadcastInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr +HloBroadcastInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + dimensions()); +} + 
+HloMapInstruction::HloMapInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* map_computation, + tensorflow::gtl::ArraySlice static_operands) + : HloInstruction(HloOpcode::kMap, shape) { + CHECK(static_operands.empty()) << "static_operands not yet supported"; + for (auto operand : operands) { + AppendOperand(operand); + } + AppendComputation(map_computation); + // TODO(b/65689298) Remove code below once Map is generalized to accept + // arbitrary map dimensions. + dimensions_.resize(ShapeUtil::Rank(shape)); + std::iota(dimensions_.begin(), dimensions_.end(), 0); +} + +HloInstructionProto HloMapInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +bool HloMapInstruction::IsElementwise() const { + if (!dimensions().empty()) { + // Check that the map is executed in elementwise compatible dimensions. + if (dimensions().size() != shape().dimensions_size()) { + return false; + } + for (int i = 0; i < dimensions().size(); ++i) { + if (dimensions()[i] != i) { + return false; + } + } + } + return true; +} + +std::vector HloMapInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloMapInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return eq_computations(to_apply(), other.to_apply()); +} + +std::unique_ptr HloMapInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(shape, new_operands, to_apply()); +} + +HloSliceInstruction::HloSliceInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides) + : 
HloInstruction(HloOpcode::kSlice, shape), + slice_starts_(start_indices.begin(), start_indices.end()), + slice_limits_(limit_indices.begin(), limit_indices.end()), + slice_strides_(strides.begin(), strides.end()) { + AppendOperand(operand); + // For backward compatibility with old serialized computations: if there are + // no strides, assume all strides are 1. + // TODO(b/63317920): remove this code. + if (slice_strides_.empty()) { + slice_strides_ = std::vector(start_indices.size(), 1LL); + } +} + +HloInstructionProto HloSliceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int i = 0; i < slice_starts_.size(); ++i) { + auto* slice_dimension = proto.add_slice_dimensions(); + slice_dimension->set_start(slice_starts_[i]); + slice_dimension->set_limit(slice_limits_[i]); + slice_dimension->set_stride(slice_strides_[i]); + } + return proto; +} + +std::vector HloSliceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector bounds; + bounds.reserve(slice_starts_.size()); + const bool omit_stride = + std::all_of(slice_strides_.begin(), slice_strides_.end(), + [](int64 stride) { return stride == 1; }); + for (int i = 0; i < slice_starts_.size(); ++i) { + string stride_str = omit_stride ? 
"" : StrCat(":", slice_strides_[i]); + bounds.push_back( + StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]")); + } + return {StrCat("slice={", Join(bounds, ", "), "}")}; +} + +bool HloSliceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& other_slice = static_cast(other); + return slice_starts_ == other_slice.slice_starts_ && + slice_limits_ == other_slice.slice_limits_ && + slice_strides_ == other_slice.slice_strides_; +} + +std::unique_ptr HloSliceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], slice_starts_, + slice_limits_, slice_strides_); +} + +HloConstantInstruction::HloConstantInstruction(std::unique_ptr literal) + : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()), + literal_(std::move(literal)) {} + +HloInstructionProto HloConstantInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_literal() = literal_->ToProto(); + return proto; +} + +bool HloConstantInstruction::IsElementwise() const { return true; } + +void HloConstantInstruction::RelayoutConstant(const Layout& new_layout, + const ShapeIndex& shape_index) { + Shape* mutable_array_subshape = + ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index); + CHECK(ShapeUtil::IsArray(*mutable_array_subshape)); + + // Normally array_subshape will always have a layout, but this invariant is + // temporarily broken in LayoutAssignment::AssignLayouts. 
+ + if (!mutable_array_subshape->has_layout() || + !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) { + literal_ = literal_->Relayout(new_layout, shape_index); + *mutable_array_subshape->mutable_layout() = new_layout; + } +} + +bool HloConstantInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& other_slice = static_cast(other); + return literal() == other_slice.literal(); +} + +std::unique_ptr +HloConstantInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(literal_->CloneToUnique()); +} + +string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { + string operands; + // For constants, show the actual value in place of an empty operand list. + if ((!ShapeUtil::IsTuple(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) || + options.print_large_constants()) { + // Literal::ToString emits multidimensional arrays over multiple + // lines. Compact this into one line by stripping out white space. + string tmp = literal().ToString(); + std::replace(tmp.begin(), tmp.end(), '\n', ' '); + std::vector v = tensorflow::str_util::Split(tmp, ' '); + bool first = true; + // Concatenate elements in "v" with spaces separating them, but ignoring + // empty entries. + for (const auto& s : v) { + if (s.empty()) { + continue; + } + StrAppend(&operands, (first ? "" : " "), s); + first = false; + } + } else { + // Do not show large constants or tuples. 
+ operands = "{...}"; + } + return operands; +} + +HloTraceInstruction::HloTraceInstruction(const string& tag, + HloInstruction* operand) + : HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()), + literal_(Literal::CreateR1U8(tag)) { + AppendOperand(operand); + operand->set_tracing(this); +} + +HloInstructionProto HloTraceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_literal() = literal_->ToProto(); + return proto; +} + +bool HloTraceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return false; +} + +std::unique_ptr HloTraceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode()); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h new file mode 100644 index 0000000000000000000000000000000000000000..ecd4a319128a3b239c008d7cf5cea38c9c1ce6ca --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -0,0 +1,494 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// All HloInstruction subclasses are put in this file. 
+ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { + +class HloBatchNormInstruction : public HloInstruction { + public: + // Returns the feature_index field associated with the instruction. The index + // represents the index of the feature dimension. + int64 feature_index() const { return feature_index_; } + + // Returns the epsilon value associated with the instruction. This is a small + // number added to the variance to avoid divide-by-zero error. + float epsilon() const { return epsilon_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + protected: + explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape, + HloInstruction* operand, + HloInstruction* scale, float epsilon, + int64 feature_index); + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // A small float number added to the variance to avoid divide-by-zero error. + float epsilon_ = 0.0f; + + // An integer value representing the index of the feature dimension. + int64 feature_index_ = -1; +}; + +class HloBatchNormTrainingInstruction : public HloBatchNormInstruction { + public: + explicit HloBatchNormTrainingInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* scale, + HloInstruction* offset, + float epsilon, int64 feature_index); + + private: + // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloBatchNormInferenceInstruction : public HloBatchNormInstruction { + public: + explicit HloBatchNormInferenceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, + float epsilon, int64 feature_index); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloBatchNormGradInstruction : public HloBatchNormInstruction { + public: + explicit HloBatchNormGradInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* mean, HloInstruction* variance, + HloInstruction* grad_output, float epsilon, int64 feature_index); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloFftInstruction : public HloInstruction { + public: + explicit HloFftInstruction(const Shape& shape, HloInstruction* operand, + FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + FftType fft_type() const { return fft_type_; } + + const std::vector& fft_length() const { return fft_length_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. 
+ std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // Describes FFT type for an FFT instruction. + FftType fft_type_ = FftType::FFT; + + // Indicates the FFT length for an FFT instruction. + std::vector fft_length_; +}; + +class HloSendRecvInstruction : public HloInstruction { + public: + // Returns the channel id associated with the instruction. The id is + // shared between each Send/Recv pair and is globally unique to identify each + // channel. + int64 channel_id() const { return channel_id_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + protected: + explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, + int64 channel_id); + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Represents a unique identifier for each Send/Recv instruction pair. + int64 channel_id_; +}; + +class HloSendInstruction : public HloSendRecvInstruction { + public: + explicit HloSendInstruction(HloInstruction* operand, int64 channel_id); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloSendDoneInstruction : public HloSendRecvInstruction { + public: + explicit HloSendDoneInstruction(HloSendInstruction* operand); + + private: + // Implementation for non-common logic of CloneWithNewOperands. 
+ std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloRecvInstruction : public HloSendRecvInstruction { + public: + explicit HloRecvInstruction(const Shape& shape, int64 channel_id); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloRecvDoneInstruction : public HloSendRecvInstruction { + public: + explicit HloRecvDoneInstruction(HloRecvInstruction* operand); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloReverseInstruction : public HloInstruction { + public: + explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. 
+ std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloConcatenateInstruction : public HloInstruction { + public: + explicit HloConcatenateInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + int64 dimension); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Accessor for the dimension in which a concatenate HLO should occur. + int64 concatenate_dimension() const { return dimensions(0); } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloReduceInstruction : public HloInstruction { + public: + explicit HloReduceInstruction( + const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. 
+ HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloTransposeInstruction : public HloInstruction { + public: + explicit HloTransposeInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns whether this instruction does a rank-2 transposition. + bool IsRank2Transpose() const; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloBroadcastInstruction : public HloInstruction { + public: + explicit HloBroadcastInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice broadcast_dimension); + // Returns the dimension sizes or numbers associated with this instruction. 
+ const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloMapInstruction : public HloInstruction { + public: + explicit HloMapInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* map_computation, + tensorflow::gtl::ArraySlice static_operands = {}); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Returns true if this instruction is binary and elementwise. + bool IsElementwise() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. 
+ std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloSliceInstruction : public HloInstruction { + public: + explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + + HloInstructionProto ToProto() const override; + + // Returns the start index in the given dimension for a slice node. + int64 slice_starts(int64 dimension) const { return slice_starts_[dimension]; } + const std::vector& slice_starts() const { return slice_starts_; } + + // Returns the (exclusive) limit index in the given dimension for a slice + // node. + int64 slice_limits(int64 dimension) const { return slice_limits_[dimension]; } + const std::vector& slice_limits() const { return slice_limits_; } + + // Returns the stride in the given dimension for a slice node. + int64 slice_strides(int64 dimension) const { + return slice_strides_[dimension]; + } + const std::vector& slice_strides() const { return slice_strides_; } + + // Returns the flag that describes whether a slice must be lowered into an + // offset into the original operand. + bool IsInPlaceSlice() const { return is_in_place_slice_; } + + // Sets and returns the flag that describes whether a slice must be lowered + // into an offset into the original operand. + bool SetIsInPlaceSlice(bool value) { + is_in_place_slice_ = value; + return value; + } + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. 
+ std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // Describes the [begin, end) index range for a slice. + std::vector slice_starts_; + std::vector slice_limits_; + std::vector slice_strides_; + + // Describes whether the slice can be lowered to an offset into the operand. + bool is_in_place_slice_ = false; +}; + +class HloConstantInstruction : public HloInstruction { + public: + explicit HloConstantInstruction(std::unique_ptr literal); + // Returns the literal associated with this instruction. + const Literal& literal() const { return *literal_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + // Returns true if this instruction is elementwise on all its operands. + bool IsElementwise() const override; + + // Change the layout for a Constant Hlo instruction to match new_layout. For + // tuple shaped constants shape_index is the path to the internal array + // subshape whose layout needs to be changed. + void RelayoutConstant(const Layout& new_layout, + const ShapeIndex& shape_index = {}); + + private: + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + string OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + // TODO(b/36360764): Remove unique_ptr wrapping. + std::unique_ptr literal_; +}; + +class HloTraceInstruction : public HloInstruction { + public: + explicit HloTraceInstruction(const string& tag, HloInstruction* operand); + // Returns a tag to be used in tracing.
+ string TracingTag() const { return literal_->GetR1U8AsString(); } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + // TODO(b/36360764): Remove unique_ptr wrapping. + std::unique_ptr literal_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc similarity index 95% rename from tensorflow/compiler/xla/tools/parser/hlo_lexer.cc rename to tensorflow/compiler/xla/service/hlo_lexer.cc index 350db126535e418cbfa914edd958f47ba90a3ee5..f0d9fdbc8f86da0bb9d7f9235239df677c9506bc 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" +#include "tensorflow/compiler/xla/service/hlo_lexer.h" #include @@ -26,9 +26,8 @@ limitations under the License. 
#include "tensorflow/core/platform/regexp.h" namespace xla { -namespace tools { -using tensorflow::StringPiece; +using ::tensorflow::StringPiece; namespace { @@ -67,12 +66,12 @@ bool HloLexer::CanDereference(const char* ptr) const { return ptr < buf_.end() && ptr >= buf_.begin(); } -StringPiece HloLexer::StringPieceFromPointers(const char* begin, - const char* end) const { +tensorflow::StringPiece HloLexer::StringPieceFromPointers( + const char* begin, const char* end) const { CHECK(begin <= end); CHECK(begin == buf_.end() || CanDereference(begin)); CHECK(end == buf_.end() || CanDereference(end)); - return StringPiece(begin, end - begin); + return tensorflow::StringPiece(begin, end - begin); } tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( @@ -197,7 +196,8 @@ TokKind HloLexer::LexIdentifier() { return TokKind::kAttributeName; } - StringPiece identifier = StringPieceFromPointers(token_start_, current_ptr_); + tensorflow::StringPiece identifier = + StringPieceFromPointers(token_start_, current_ptr_); // See if this is a keyword. #define KEYWORD(STR) \ @@ -332,23 +332,24 @@ std::pair HloLexer::GetLineAndColumn(LocTy location) const { line_no_cache_.last_query = ptr; line_no_cache_.line_no_of_query = line_no; size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n'); - if (line_offset == StringPiece::npos) { + if (line_offset == tensorflow::StringPiece::npos) { line_offset = 0; } return {line_no, ptr - start - line_offset}; } -StringPiece HloLexer::GetLine(LocTy loc) const { +tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const { if (!CanDereference(loc)) { return "LINE OUT OF RANGE"; } size_t line_start = StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n'); - const char* start = line_start == StringPiece::npos + const char* start = line_start == tensorflow::StringPiece::npos ? 
buf_.begin() : buf_.begin() + line_start + 1; size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n'); - const char* end = line_end == StringPiece::npos ? buf_.end() : loc + line_end; + const char* end = + line_end == tensorflow::StringPiece::npos ? buf_.end() : loc + line_end; return StringPieceFromPointers(start, end); } @@ -370,7 +371,7 @@ TokKind HloLexer::LexString() { static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"}; if (RE2::Consume(&consumable, *escaping_pattern)) { current_ptr_ = consumable.begin(); - StringPiece raw = + tensorflow::StringPiece raw = StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); string error; if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) { @@ -453,5 +454,4 @@ string TokKindToString(TokKind kind) { } } -} // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h similarity index 90% rename from tensorflow/compiler/xla/tools/parser/hlo_lexer.h rename to tensorflow/compiler/xla/service/hlo_lexer.h index 27880b9b8afbfa58abfedc3b2cecd5236b78a6d6..ceb674f25e94ac3ac2e6a4a0687a93ffdcd065e0 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_ #include -#include "tensorflow/compiler/xla/tools/parser/hlo_token.h" +#include "tensorflow/compiler/xla/service/hlo_token.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -27,9 +27,11 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace xla { -namespace tools { // Lexer for the HloModule::ToString() format text. +// +// This class is meant to be used by hlo_parser.cc. You shouldn't need to use +// it directly. class HloLexer { public: explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) { @@ -57,7 +59,7 @@ class HloLexer { CHECK(GetKind() == TokKind::kShape); return shape_val_; } - int64 GetInt64Val() const { + tensorflow::int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); return int64_val_; } @@ -114,7 +116,7 @@ class HloLexer { TokKind current_kind_; string str_val_; Shape shape_val_; - int64 int64_val_; + tensorflow::int64 int64_val_; double decimal_val_; struct LineNoCacheTy { @@ -125,7 +127,6 @@ class HloLexer { mutable LineNoCacheTy line_no_cache_{nullptr, 0}; }; -} // namespace tools } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc new file mode 100644 index 0000000000000000000000000000000000000000..43c41ece6efc4f9e8ca74f16e0f63d29abc4de4e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -0,0 +1,306 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" + +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +using Worklist = std::deque; +using Workset = std::unordered_set; + +namespace { + +void AddToWorklist(const HloInstruction* instruction, Worklist* worklist, + Workset* workset) { + if (workset->count(instruction) == 0) { + worklist->push_back(instruction); + workset->insert(instruction); + VLOG(3) << "ADD instruction: " << instruction->name(); + } +} + +using VisitorFunction = std::function; + +void ForEachLiveIndex(const ShapeTree& index_tree, + const VisitorFunction& func) { + 
index_tree.ForEachElement([&](const ShapeIndex& shape_index, bool live) { + if (live) { + func(shape_index); + } + }); +} + +// Marks 'instruction' output live at 'shape_index'. +// Adds to 'worklist' iff: +// *) 'instruction' is not already on worklist. +// *) 'shape_index' has not yet been visited. +void MarkLiveAtIndex(const HloInstruction* instruction, + const ShapeIndex& shape_index, + HloLivenessAnalysis::HloIndexMap* live_index_map, + Worklist* worklist, Workset* workset) { + auto it = live_index_map->find(instruction); + if (it == live_index_map->end()) { + auto it_added = live_index_map->emplace( + std::piecewise_construct, std::forward_as_tuple(instruction), + std::forward_as_tuple(instruction->shape(), /*init_value=*/false)); + it = it_added.first; + } + if (it->second.element(shape_index) == false) { + AddToWorklist(instruction, worklist, workset); + *it->second.mutable_element(shape_index) = true; + VLOG(3) << "MARK instruction: " << instruction->name() + << " shape_index: " << shape_index.ToString(); + } +} + +// Marks 'instruction' live at all shape indices in its output. 
+void MarkLiveAtAllIndices(const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, + Worklist* worklist, Workset* workset) { + bool add_to_worklist = false; + auto it = live_index_map->find(instruction); + if (it == live_index_map->end()) { + live_index_map->emplace( + std::piecewise_construct, std::forward_as_tuple(instruction), + std::forward_as_tuple(instruction->shape(), /*init_value=*/true)); + add_to_worklist = true; + } else { + ShapeUtil::ForEachSubshape( + instruction->shape(), + [&](const Shape& sub_shape, const ShapeIndex& shape_index) { + if (it->second.element(shape_index) == false) { + add_to_worklist = true; + *it->second.mutable_element(shape_index) = true; + VLOG(3) << "MARK instruction: " << instruction->name() + << " shape_index: " << shape_index.ToString(); + } + }); + } + if (add_to_worklist) { + AddToWorklist(instruction, worklist, workset); + } +} + +// Propagates liveness through Tuple instructions. +// *) For each tuple operand: +// *) For tuple output shape index associated with operand: +// *) Propagate live shape indices to tuple operand at the associated +// shape index in the operand's output, and add to worklist. +void PropagateLivenessThroughTuple( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset) { + CHECK_EQ(instruction->opcode(), HloOpcode::kTuple); + for (int64 operand_index = 0; operand_index < instruction->operand_count(); + ++operand_index) { + const ShapeTree& index_tree = FindOrDie(*live_index_map, instruction); + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + if (shape_index.empty() || shape_index[0] != operand_index) { + return; + } + // Mark top-level index of operand at 'operand_index'. + MarkLiveAtIndex(instruction->operand(operand_index), {}, live_index_map, + worklist, workset); + // Mark sub-shape index of operand at 'operand_index'.
+ ShapeIndex operand_shape_index; + for (int i = 1; i < shape_index.size(); ++i) { + operand_shape_index.push_back(shape_index[i]); + } + MarkLiveAtIndex(instruction->operand(operand_index), operand_shape_index, + live_index_map, worklist, workset); + }); + } +} + +// Propagates liveness through GetTupleElement instructions. +// *) For each live index in GetTupleElement output, mark output of GTE operand +// at associated shape index in its output, and add to worklist. +void PropagateLivenessThroughGTE( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset) { + CHECK_EQ(instruction->opcode(), HloOpcode::kGetTupleElement); + // Mark operand top-level index. + MarkLiveAtIndex(instruction->operand(0), {}, live_index_map, worklist, + workset); + const ShapeTree& index_tree = FindOrDie(*live_index_map, instruction); + // Propagate live shape indices along GTE -> Tuple edge. + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + ShapeIndex operand_shape_index(shape_index); + operand_shape_index.push_front(instruction->tuple_index()); + MarkLiveAtIndex(instruction->operand(0), operand_shape_index, + live_index_map, worklist, workset); + }); +} + +// Propagates liveness through While instructions. +// *) For each live index in While output, mark shape index of while.body.root +// and while.operand (adding each to worklist). +// *) Mark while.cond.root and add to worklist. +void PropagateLivenessThroughWhile( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset) { + CHECK_EQ(instruction->opcode(), HloOpcode::kWhile); + const ShapeTree& index_tree = FindOrDie(*live_index_map, instruction); + + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + // Propagate liveness to while body computation root instruction. 
+ MarkLiveAtIndex(instruction->while_body()->root_instruction(), shape_index, + live_index_map, worklist, workset); + // Propagate liveness to tuple-shaped operand. + MarkLiveAtIndex(instruction->operand(0), shape_index, live_index_map, + worklist, workset); + }); + + // Propagate liveness to while condition computation root instruction. + MarkLiveAtIndex(instruction->while_condition()->root_instruction(), {}, + live_index_map, worklist, workset); +} + +// Propagates liveness out of Parameter instructions to callers and aliasing +// positions. This can occur if liveness propagates to a parameter in the +// while.condition computation, requiring liveness to propagate out to caller +// callsite while (and while.body.root). +void PropagateLivenessToParameterCallers( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset, CallGraph* call_graph) { + CHECK_EQ(instruction->opcode(), HloOpcode::kParameter); + const CallGraphNode& call_graph_node = + call_graph->GetNode(instruction->parent()); + if (call_graph_node.context() == CallContext::kSequential) { + for (const CallSite& callsite : call_graph_node.caller_callsites()) { + if (callsite.instruction()->opcode() == HloOpcode::kWhile) { + auto* xla_while = callsite.instruction(); + const ShapeTree& index_tree = + FindOrDie(*live_index_map, instruction); + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + // Propagate liveness to while result{shape_index} + MarkLiveAtIndex(xla_while, shape_index, live_index_map, worklist, + workset); + // Propagate liveness to while body root{shape_index}. + MarkLiveAtIndex(xla_while->while_body()->root_instruction(), + shape_index, live_index_map, worklist, workset); + // Propagate liveness to operand(0){shape_index}. 
+ MarkLiveAtIndex(xla_while->operand(0), shape_index, live_index_map, + worklist, workset); + }); + } + } + } +} + +} // namespace + +HloLivenessAnalysis::HloLivenessAnalysis(const HloModule& module) + : module_(module), call_graph_(CallGraph::Build(&module)) {} + +// Runs liveness analysis on 'module_'. +// Initializes worklist with entry root instruction (and any instruction with +// side-effects), marking all of their output shape indices live. +// Visits elements on worklist, propagating liveness from an instruction's +// live output shape indices to its called computations and operands. +void HloLivenessAnalysis::RunAnalysis() { + Worklist worklist; + Workset workset; + // Add entry computation root instruction. + MarkLiveAtAllIndices(module_.entry_computation()->root_instruction(), + &live_index_map_, &worklist, &workset); + for (auto* computation : module_.computations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->HasSideEffectNoRecurse()) { + // Add instructions with side effects.
+ MarkLiveAtAllIndices(instruction, &live_index_map_, &worklist, + &workset); + } + } + } + + while (!worklist.empty()) { + const HloInstruction* instruction = worklist.front(); + worklist.pop_front(); + workset.erase(workset.find(instruction)); + VLOG(1) << "VISIT instruction: " << instruction->name(); + + if (instruction->opcode() == HloOpcode::kTuple) { + PropagateLivenessThroughTuple(instruction, &live_index_map_, &worklist, + &workset); + } else if (instruction->opcode() == HloOpcode::kGetTupleElement) { + PropagateLivenessThroughGTE(instruction, &live_index_map_, &worklist, + &workset); + } else if (instruction->opcode() == HloOpcode::kWhile && + ShapeUtil::IsTuple(instruction->shape())) { + PropagateLivenessThroughWhile(instruction, &live_index_map_, &worklist, + &workset); + } else if (instruction->opcode() == HloOpcode::kParameter && + ShapeUtil::IsTuple(instruction->shape())) { + PropagateLivenessToParameterCallers(instruction, &live_index_map_, + &worklist, &workset, + call_graph_.get()); + } else { + // Propagate liveness to called computations. + for (auto* called_computation : instruction->called_computations()) { + MarkLiveAtAllIndices(called_computation->root_instruction(), + &live_index_map_, &worklist, &workset); + } + // Propagate liveness to operands. 
+ for (HloInstruction* operand : instruction->operands()) { + MarkLiveAtAllIndices(operand, &live_index_map_, &worklist, &workset); + } + } + } +} + +bool HloLivenessAnalysis::IsLive(const HloInstruction* instruction, + const ShapeIndex& shape_index) const { + if (ContainsKey(live_index_map_, instruction)) { + return FindOrDie(live_index_map_, instruction).element(shape_index); + } + return false; +} + +/* static */ +StatusOr> HloLivenessAnalysis::Run( + const HloModule& module) { + VLOG(1) << "HloLivenessAnalysis::Run on module " << module.name(); + XLA_VLOG_LINES(2, module.ToString()); + + auto liveness_analysis = WrapUnique(new HloLivenessAnalysis(module)); + + liveness_analysis->RunAnalysis(); + + return std::move(liveness_analysis); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.h b/tensorflow/compiler/xla/service/hlo_liveness_analysis.h new file mode 100644 index 0000000000000000000000000000000000000000..fe55a8070a42a3d68836dd32cf7ce5823dd77951 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.h @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ + +#include + +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_value.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Analysis which identifies all live {HloInstruction, ShapeIndex} pairs in +// an HLO module. +// +// HloLivenessAnalysis marks the shape index of each live output of each +// instruction in the module, by propagating live shape index information +// from an instruction to its called computations and operands. +class HloLivenessAnalysis { + public: + // Maps from an HloInstruction to its live/dead output shape indices. + using HloIndexMap = + std::unordered_map>; + + // Runs liveness analysis on 'module'. Returns HloLivenessAnalysis object + // which exports liveness for each {HloInstruction, ShapeIndex} in 'module'. + static StatusOr> Run( + const HloModule& module); + + // Returns true if output of 'instruction' at 'shape_index' is live. + // Returns false otherwise. 
+ bool IsLive(const HloInstruction* instruction, + const ShapeIndex& shape_index) const; + + private: + HloLivenessAnalysis(const HloModule& module); + + void RunAnalysis(); + + const HloModule& module_; + std::unique_ptr call_graph_; + HloIndexMap live_index_map_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0275294a1a86cef13e5b267ad578f30cc18858dc --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -0,0 +1,402 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class HloLivenessAnalysisTest : public HloTestBase { + protected: + HloLivenessAnalysisTest() {} + + // Run liveness analysis on the member module. For convenience returns a + // reference to the generated analysis stored in analysis_. + const HloLivenessAnalysis& RunLiveness(HloModule* module) { + liveness_ = HloLivenessAnalysis::Run(*module).ConsumeValueOrDie(); + return *liveness_; + } + + HloInstruction* GetInstruction(HloModule* module, const string& name) { + HloInstruction* to_return = nullptr; + for (auto* comp : module->computations()) { + for (auto* inst : comp->instructions()) { + if (inst->name() == name) { + to_return = inst; + break; + } + } + } + return CHECK_NOTNULL(to_return); + } + + std::unique_ptr liveness_; +}; + +// Test that add instruction at entry root is live at all output shape indices. 
+TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + ROOT add = s32[] add(constant.1, constant.2) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); +} + +// Test that a dead add instruction is marked as dead by analysis. +TEST_F(HloLivenessAnalysisTest, DeadAdd) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + add.1 = s32[] add(constant.1, constant.2) + ROOT add.2 = s32[] add(constant.1, constant.2) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "add.1"), {})); +} + +// Test that all output shape indices of entry root tuple (and defining +// instruction in its output) are marked live. 
+TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + ROOT tuple.1 = (s32[], s32[]) tuple(constant.1, constant.2) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); +} + +// Tests that all outputs of nested tuple and entry root (and defining +// instruction values appearing in its output) are marked live. +TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(1) + constant.2 = s32[] constant(2) + constant.3 = s32[] constant(3) + tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3) + ROOT tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1})); + 
EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +// Tests that GTE at entry root of Tuple instruction only propagates liveness +// to the live elements in tuple. +TEST_F(HloLivenessAnalysisTest, GteOfTuple) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + tuple.1 = (s32[], s32[]) tuple(constant.1, constant.2) + ROOT get-tuple-element.1 = s32[] get-tuple-element(tuple.1), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); +} + +// Tests that GTE at entry root of nested Tuple instruction only propagates +// liveness to the live elements in tuple.
+TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + constant.3 = s32[] constant(2) + tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3) + tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1) + ROOT get-tuple-element.1 = (s32[], s32[]) get-tuple-element(tuple.2), index=1 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {})); + EXPECT_TRUE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {0})); + EXPECT_TRUE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +// Tests that GTE of GTE (at entry root) of nested Tuple instruction only +// propagates liveness to the live elements in tuple.
+TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + constant.3 = s32[] constant(2) + tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3) + tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1) + get-tuple-element.1 = (s32[], s32[]) get-tuple-element(tuple.2), index=1 + ROOT get-tuple-element.2 = s32[] get-tuple-element(get-tuple-element.1), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.2"), {})); + + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {})); + EXPECT_TRUE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {0})); + EXPECT_FALSE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0})); + EXPECT_FALSE( + liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +// Test that live/dead while tuple elements are marked live/dead correctly. 
+TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply.0 = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple.0 = (s32[], s32[3]{0}) tuple(add.0, multiply.0) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + while.0 = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while.0), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.4"), {})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {1})); + + // While operand. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); + + // While body. 
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.0"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "multiply.0"), {})); +} + +// Tests that a tuple element live in while.cond computation, propagates +// liveness to while.body.root/while.result/while.operand (where it is unused). +TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply.0 = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple.0 = (s32[], s32[3]{0}) tuple(add.0, multiply.0) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=1 + add.1 = s32[] add(get-tuple-element.3, get-tuple-element.4) + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(add.1, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + while.0 = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.5 = s32[] get-tuple-element(while.0), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.5"), {})); + + 
EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {1})); + + // While operand. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.4"), {})); + + // While body. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "multiply.0"), {})); +} + +// Tests that a use of while.result{0} propagates liveness to +// while.body.param{1} to while.body.root{1}, and then to while.body.param{2}. 
+TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + get-tuple-element.2 = s32[] get-tuple-element(loop_var.1), index=1 + add.1 = s32[] add(get-tuple-element.1, get-tuple-element.2) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.1), index=2 + multiply.1 = s32[] multiply(get-tuple-element.3, get-tuple-element.3) + ROOT tuple.1 = (s32[], s32[], s32[]) tuple(add.1, get-tuple-element.3, multiply.1) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[], s32[]) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0 + constant.1 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.4, constant.1) + } + ENTRY SimpleLoop { + constant.2 = s32[] constant(0) + constant.3 = s32[] constant(1) + constant.4 = s32[] constant(2) + tuple.2 = (s32[], s32[], s32[]) tuple(constant.2, constant.3, constant.4) + while.1 = (s32[], s32[], s32[]) while(tuple.2), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.5 = s32[] get-tuple-element(while.1), index=0 + })") + .ValueOrDie(); + + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.5"), {})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {2})); + // While operand. 
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {2})); + // While body root. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {2})); + // While body param. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {2})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index c33bdadf1c7145bf2aff09b01423c6c21382da0c..c570b420c21fed4d7828feb24ee5c7859db94a79 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -17,6 +17,7 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/lib/gtl/optional.h" @@ -324,6 +325,12 @@ inline ::testing::Matcher Sharding( return ::testing::MakeMatcher( new ::xla::testing::HloShardingMatcher(sharding)); } +// Matcher for Sharding from sharding string +inline ::testing::Matcher Sharding( + tensorflow::StringPiece sharding) { + return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher( + ParseSharding(sharding).ValueOrDie())); +} // Verifies that no HloSharding is set for an HLO instruction. inline ::testing::Matcher NoSharding() { return ::testing::MakeMatcher( diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 016cc01e33840aa195dfc0a21e8ac8f3d24a3e06..9a3010cf1ff75e840130d8442bbe26d6041cef25 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -14,8 +14,8 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace op = xla::testing::opcode_matchers; using ::testing::_; @@ -147,6 +147,18 @@ TEST(HloMatchersTest, ShardingMatcher) { "param.1"); p1->set_sharding(HloSharding::AssignDevice(1)); + auto tuple_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {7}), ShapeUtil::MakeShape(S32, {9}), + ShapeUtil::MakeShape(F32, {11})}); + auto p2 = HloInstruction::CreateParameter(1, tuple_shape, "param.2"); + Array assignment({2}); + assignment.SetValues({0, 1}); + auto sharding = HloSharding::Tuple( + tuple_shape, + {HloSharding::Tile(ShapeUtil::MakeShape(F32, {5}), assignment), + HloSharding::AssignDevice(1), HloSharding::Replicate()}); + p2->set_sharding(sharding); + EXPECT_THAT(p0.get(), op::NoSharding()); EXPECT_THAT(p0.get(), ::testing::Not(op::Sharding(HloSharding::AssignDevice(1)))); @@ -155,6 +167,11 @@ TEST(HloMatchersTest, ShardingMatcher) { ::testing::Not(op::Sharding(HloSharding::AssignDevice(0)))); EXPECT_THAT(p1.get(), op::Sharding(HloSharding::AssignDevice(1))); + EXPECT_THAT( + p2.get(), + op::Sharding( + "{{f32[5] devices=[2]0,1}, {maximal device=1}, {replicated}}")); + EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))), "%param.0 = f32[5]{0} parameter(0) has no sharding (expected: " "{maximal device=1})"); @@ -178,7 +195,7 @@ ENTRY DotOperationFusion_TransposeFusion { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Dot(op::Parameter(0), op::Parameter(1), diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 
5308fb5848341b6faee64bc1ad865f9bb3bcdbe9..9c59374b4a9d7e3dbfb99d8a6b30d4230e553658 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -32,15 +32,6 @@ limitations under the License. namespace xla { -HloModule::HloModule(const string& name, - const VersionedComputationHandle& entry_computation_handle, - const HloModuleConfig& config) - : name_(NameUniquer::GetSanitizedName(name)), - config_(config), - has_entry_computation_handle_(true), - entry_computation_handle_(entry_computation_handle), - unique_id_(next_unique_module_id_++) {} - HloModule::HloModule(const string& name, const HloModuleConfig& config) : name_(NameUniquer::GetSanitizedName(name)), config_(config), @@ -234,8 +225,7 @@ HloModuleProto HloModule::ToProto() const { /* static */ StatusOr> HloModule::CreateFromProto( - const HloModuleProto& proto, const HloModuleConfig& module_config, - const VersionedComputationHandle& entry_computation_handle) { + const HloModuleProto& proto, const HloModuleConfig& module_config) { // The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. 
TF_RET_CHECK(proto.has_program_shape()) @@ -266,24 +256,43 @@ StatusOr> HloModule::CreateFromProto( << ShapeUtil::HumanStringWithLayout(expected_program_shape.result()) << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape); - auto module = MakeUnique(proto.name(), entry_computation_handle, - module_config); - tensorflow::gtl::FlatMap computation_map; + tensorflow::gtl::FlatMap to_proto_id; + std::vector> computations; + HloComputation* entry = nullptr; for (const HloComputationProto& computation_proto : proto.computations()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr computation, - HloComputation::CreateFromProto( - module.get(), computation_proto, computation_map)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr computation, + HloComputation::CreateFromProto(computation_proto, computation_map)); CHECK_NE(computation.get(), nullptr); int64 computation_id = computation_proto.id(); TF_RET_CHECK(computation_id != -1); TF_RET_CHECK(!ContainsKey(computation_map, computation_id)); + computation_map[computation_id] = computation.get(); + to_proto_id[computation.get()] = computation_id; + if (computation_id == proto.entry_computation_id()) { + entry = computation.get(); + } + computations.push_back(std::move(computation)); + } + TF_RET_CHECK(entry != nullptr); + + auto module = MakeUnique(proto.name(), module_config); + + // Sort the computations in the proto id's order. + std::sort(computations.begin(), computations.end(), + [&](const std::unique_ptr& a, + const std::unique_ptr& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); + + // Add sorted computations to the module. + for (auto& computation : computations) { + bool is_entry = computation.get() == entry; // Don't uniquify names because we want names to be stable across // serialization and deserialization. 
- computation_map[computation_id] = module->AddComputationInternal( - std::move(computation), - /*is_entry=*/proto.entry_computation_id() == computation_id, - /*uniquify_names=*/false); + module->AddComputationInternal(std::move(computation), is_entry, + /*uniquify_names=*/false); } TF_RET_CHECK(module->entry_computation_ != nullptr); @@ -381,7 +390,7 @@ HloInstruction* HloModule::OutlineExpressionFromComputation( // as a parameter in the new function. arguments.push_back(old_operand); *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter( - parameter_count, old_operand->shape(), "")); + parameter_count, old_operand->shape(), "p")); ++parameter_count; } TF_CHECK_OK( @@ -476,7 +485,18 @@ std::list HloModule::MakeComputationPostOrder() const { added_computations.insert(computation.get()); } } - CHECK_EQ(post_order.size(), computations_.size()); + if (post_order.size() != computations_.size()) { + for (HloComputation* computation : post_order) { + LOG(ERROR) << "Post Order: " << computation->name() << " (" + << computation->parent()->name() << ")"; + } + for (auto& computation : computations_) { + LOG(ERROR) << "Computations: " << computation->name() << " (" + << computation->parent()->name() << ")"; + } + LOG(FATAL) << "Mismatch computation count: post_order=" << post_order.size() + << " computation_count=" << computations_.size(); + } return post_order; } @@ -494,57 +514,26 @@ std::vector HloModule::MakeNonfusionComputations() const { std::unique_ptr HloModule::Clone(const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; auto module = MakeUnique(name_ + "-" + suffix, config_); - module->entry_computation_handle_ = entry_computation_handle_; - module->has_entry_computation_handle_ = has_entry_computation_handle_; - std::unordered_map clone_map; - for (auto& computation : computations_) { - if (computation->IsFusionComputation()) { - // Cloning of a fused computation is handled by its fusion instruction. 
- continue; - } - - // When cloning a computation, pass in the new module, so that for any - // fusion instruction in this computation, the fused computation will be - // deep cloned to the new module. - auto cloned_computation = computation->Clone(suffix, module.get()); - InsertOrDie(&clone_map, computation.get(), cloned_computation.get()); - - if (entry_computation_ == computation.get()) { - module->AddEntryComputation(std::move(cloned_computation)); - } else { - module->AddEmbeddedComputation(std::move(cloned_computation)); - } - } - - for (auto& cloned_computation : module->computations_) { - for (auto* instruction : cloned_computation->instructions()) { - // Rewrite instruction's called_computation to point to the cloned - // computations. - instruction->ReplaceCalledComputations([&](HloComputation* hlo) { - if (hlo->IsFusionComputation()) { - // Cloning of a fused computation has already been handled when its - // fusion instruction is cloned. So this hlo computation is already - // the cloned one. 
- return hlo; - } - return FindOrDie(clone_map, hlo); - }); - } - } + HloCloneContext context(module.get(), suffix); + auto cloned_computation = entry_computation_->Clone(suffix, &context); + module->AddEntryComputation(std::move(cloned_computation)); return module; } -HloComputation* HloModule::DeepCloneComputation(HloComputation* computation) { - HloComputation* clone = AddEmbeddedComputation(computation->Clone("", this)); - TF_CHECK_OK( - clone->root_instruction()->Accept([this](HloInstruction* instruction) { - instruction->ReplaceCalledComputations([this](HloComputation* callee) { - return DeepCloneComputation(callee); - }); - return Status::OK(); - })); - return clone; +HloComputation* HloModule::DeepCloneComputation(HloComputation* computation, + HloCloneContext* context) { + HloComputation* new_computation; + if (context != nullptr) { + if ((new_computation = context->FindComputation(computation)) != nullptr) { + return new_computation; + } + new_computation = + AddEmbeddedComputation(computation->Clone(context->suffix(), context)); + } else { + new_computation = AddEmbeddedComputation(computation->Clone("")); + } + return new_computation; } uint64 HloModule::RandomNew64() const { diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 1604a7261240e5bfef7c4fb1583488bf0ae4421a..757e65bda286d983d05e5a791aa7dffe97bac945 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -26,11 +26,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -42,16 +42,20 @@ namespace xla { // Describes a compilation unit at the HLO level. // -// A HLO module contains one or more HLO computations. The module contains one -// "entry" computation which produces the result. The module also includes any -// embedded computations used by instructions such as "map" and "reduce". All -// computations are owned by the module. +// HloModule is the top-level unit in the HLO IR. It corresponds to a whole +// "program". Running a module, from beginning to end, is the only way to run +// an XLA program. +// +// A module contains one "entry computation"; this HloComputation is like main() +// in a C program. The result of running the module is the result of running +// this computation. +// +// A module also contains some number of "nested computations". Each nested +// computation is attached to an HloInstruction within some other computation. +// The meaning of the nested computation depends on the instruction it's +// attached to. class HloModule { public: - HloModule(const string& name, - const VersionedComputationHandle& entry_computation_handle, - const HloModuleConfig& config); - // Constructor without a versioned computation handle. This constructor should // only be used for HloModules used outside of the XLA service (eg // tests). 
The versioned handle is used by the service in the compilation @@ -86,8 +90,10 @@ class HloModule { std::unique_ptr Clone(const string& suffix = "clone") const; // Performs a deep clone of the computation, by recursively cloning all - // the called computations as well. - HloComputation* DeepCloneComputation(HloComputation* computation); + // the called computations as well. If the clone context is specified, it + // will be populated with the cloned object mappings. + HloComputation* DeepCloneComputation(HloComputation* computation, + HloCloneContext* context = nullptr); // Return a pointer to the entry computation of the module.. const HloComputation* entry_computation() const { @@ -115,10 +121,6 @@ class HloModule { return config_.device_entry_computation_layout(); } - const VersionedComputationHandle& entry_computation_handle() const { - return entry_computation_handle_; - } - // Gets the computations in this module. // // Returns a view of HloComputation*s, so you can iterate over this in the @@ -177,9 +179,7 @@ class HloModule { // Convert an HloModule to or from a proto. HloModuleProto ToProto() const; static StatusOr> CreateFromProto( - const HloModuleProto& proto, const HloModuleConfig& module_config, - const VersionedComputationHandle& entry_computation_handle = - VersionedComputationHandle()); + const HloModuleProto& proto, const HloModuleConfig& module_config); // Creates and returns an HloModuleConfig with an appropriate program shape // for the HLO module in the given proto. @@ -253,10 +253,6 @@ class HloModule { mutable std::mt19937_64 rng_{42}; mutable tensorflow::mutex rng_mutex_; - // Versioned handle of the entry computation of the module. - bool has_entry_computation_handle_ = false; - VersionedComputationHandle entry_computation_handle_; - // Unique name generator for computation and instruction names, which are // unique per module. 
NameUniquer computation_name_uniquer_{/*separator=*/"."}; diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc new file mode 100644 index 0000000000000000000000000000000000000000..98d20315e399c6b1a3979b5d11a89ef93869f4d9 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc @@ -0,0 +1,131 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_dce.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +bool HasSendRecv(HloComputation* computation) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == 
HloOpcode::kSend || + instruction->opcode() == HloOpcode::kSendDone || + instruction->opcode() == HloOpcode::kRecv || + instruction->opcode() == HloOpcode::kRecvDone) { + return true; + } + for (auto* sub_computation : instruction->called_computations()) { + if (HasSendRecv(sub_computation)) { + return true; + } + } + } + return false; +} + +StatusOr RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { + bool changed = false; + for (auto* computation : module->computations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kWhile) { + continue; + } + + const auto* xla_while = instruction; + auto* while_body_comp = xla_while->while_body(); + auto* while_body_param = while_body_comp->parameter_instruction(0); + auto* while_body_root = while_body_comp->root_instruction(); + + if (!ShapeUtil::IsTuple(xla_while->shape()) || + while_body_root->opcode() != HloOpcode::kTuple || + HasSendRecv(while_body_comp)) { + // Only run DCE on tuple-shaped while loops where body root is Tuple, + // with no send/recv instructions. + VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString(); + continue; + } + + // Remove dead tuple elements. + const int64 tuple_element_count = + ShapeUtil::TupleElementCount(xla_while->shape()); + for (int64 i = 0; i < tuple_element_count; ++i) { + if (liveness->IsLive(xla_while, {i})) { + continue; + } + VLOG(1) << "WhileDCE Dead while tuple element." + << " while: " << xla_while->name() << " tuple_index: " << i; + // Transform while.body computation to make tuple element at + // 'shape_index' as simple pass-through parameter (which can + // be removed later by simplification pass). 
+ HloInstruction* pass_thru_gte = while_body_comp->AddInstruction( + HloInstruction::CreateGetTupleElement( + while_body_param->shape().tuple_shapes(i), while_body_param, + i)); + // Replace while.body.root Tuple operand at 'tuple_index' with + // 'pass_thru_gte', making prior operand a dead root (to be cleaned + // up with a subsequent DCE pass). + TF_RETURN_IF_ERROR( + while_body_root->ReplaceOperandWith(i, pass_thru_gte)); + changed = true; + } + } + } + return changed; +} + +} // namespace + +StatusOr HloModuleDCE::Run(HloModule* module) { + VLOG(2) << "Before HloModuleDCE:"; + XLA_VLOG_LINES(3, module->ToString()); + + std::unique_ptr liveness; + TF_ASSIGN_OR_RETURN(liveness, HloLivenessAnalysis::Run(*module)); + + // Sweep through while instructions, transforming dead while tuple element + // computations to pass through tuple values (creating dead roots in while + // body computation in the process). + TF_ASSIGN_OR_RETURN(bool hlo_module_dce_changed, + RunWhileDCE(module, liveness.get())); + + // Run HloDCE to clean up any dead code created during HloModuleDCE. + HloDCE hlo_dce; + TF_ASSIGN_OR_RETURN(bool hlo_dce_changed, hlo_dce.Run(module)); + + VLOG(2) << "After HloModuleDCE:"; + XLA_VLOG_LINES(3, module->ToString()); + + return hlo_module_dce_changed | hlo_dce_changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h new file mode 100644 index 0000000000000000000000000000000000000000..29024085c1038961ef2b3721de1ce0e8a55ccf45 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_dce.h @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass which removes dead code from computations in the module using +// HloModule-scoped analysis (HloLivenessAnalysis). +// +// Sweeps through live instructions which cross computation boundaries (kWhile), +// and removes code at dead shape indices. +// +class HloModuleDCE : public HloPassInterface { + public: + ~HloModuleDCE() override {} + tensorflow::StringPiece name() const override { return "hlo-module-dce"; } + + // Run the pass on the given module. Returns whether the module was changed + // (instructions were removed). + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..363862e4905fc13a4ef07aeaac255259fc6b86ba --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -0,0 +1,371 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_dce.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class HloModuleDceTest : public HloTestBase { + protected: + HloModuleDceTest() {} + + // Returns whether the given instruction exists in the given computation. + bool HasInstruction(const HloComputation& computation, + const HloInstruction* instruction) { + return std::find(computation.instructions().begin(), + computation.instructions().end(), + instruction) != computation.instructions().end(); + } + + // Returns whether the while instruction with name 'while_name' in + // 'computation' passes through its tuple element at 'tuple_index' from + // parameter to root instruction. 
+ bool WhileBodyHasPassThroughTupleElement(const HloComputation* computation, + const string& while_name, + const int64 tuple_index) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kWhile && + instruction->name() == while_name) { + auto* while_body_comp = instruction->while_body(); + auto* while_body_param = while_body_comp->parameter_instruction(0); + auto* while_body_root = while_body_comp->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + return false; + } + auto* operand = while_body_root->operand(tuple_index); + if (operand->opcode() == HloOpcode::kGetTupleElement && + operand->tuple_index() == tuple_index && + operand->operand(0) == while_body_param) { + return true; + } + return false; + } + } + return false; + } +}; + +// Tests that a while with all outputs live is unmodified. +TEST_F(HloModuleDceTest, WhileWithLiveOutputs) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + })") + .ValueOrDie(); + + HloModuleDCE dce; + 
EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests a while loop with one unused output (which is used in the while loop +// body by an instruction with side-effects: rng) is unmodified. +TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], f32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = f32[] get-tuple-element(loop_var.1), index=1 + constant.2 = f32[] constant(1.0) + rng = f32[] rng(constant.2, get-tuple-element.2), distribution=rng_uniform + add.1 = s32[] add(get-tuple-element.2, constant.2) + ROOT tuple = (s32[], f32[]) tuple(add, add.1) + } + SimpleLoop.condition { + loop_var.2 = (s32[], f32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.3 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.3) + } + ENTRY SimpleLoop { + constant.4 = s32[] constant(0) + constant.5 = f32[] constant(0.0) + tuple.1 = (s32[], f32[]) tuple(constant.4, constant.5) + while = (s32[], f32[]) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests that a while loop with one dead tuple element at {1} has its while +// loop body modified to make that tuple element 
pass-through the while body. +TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + while = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // While tuple element {1} should not be pass-through before ModuleDCE. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + // While tuple element {1} should now be pass-through after ModuleDCE. + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests that a tuple element {1} used by condition computation (which appears +// dead in while.body{1} and at while.result{1}) propagates liveness of this +// tuple element to while.body{1} and at while.result{1}. 
+TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[] get-tuple-element(loop_var.1), index=1 + multiply = s32[] multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[]) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[] constant(0) + tuple.1 = (s32[], s32[]) tuple(constant.3, constant.4) + while = (s32[], s32[]) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // While tuple element {1} should not be pass-through before ModuleDCE. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + // While tuple element {1} should still not be pass-through after ModuleDCE. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests that HloModuleDCE can remove a dead tuple element at index {1} between +// two dependent while loops. 
+TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body0 { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition0 { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + SimpleLoop.body1 { + loop_var.3 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0 + constant.3 = s32[] constant(1) + add.1 = s32[] add(get-tuple-element.4, constant.3) + get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1 + multiply.1 = s32[3]{0} multiply(get-tuple-element.5, get-tuple-element.5) + ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, multiply.1) + } + SimpleLoop.condition1 { + loop_var.4 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0 + constant.4 = s32[] constant(5) + ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4) + } + ENTRY SimpleLoop { + constant.5 = s32[] constant(0) + constant.6 = s32[3]{0} constant({0, 1, 2}) + tuple.2 = (s32[], s32[3]{0}) tuple(constant.5, constant.6) + while.1 = (s32[], s32[3]{0}) while(tuple.2), condition= + SimpleLoop.condition0, body=SimpleLoop.body0 + get-tuple-element.7 = s32[] get-tuple-element(while.1), index=0 + tuple.3 = (s32[], s32[3]{0}) tuple(get-tuple-element.7, constant.6) + while.2 = (s32[], s32[3]{0}) while(tuple.3), condition= + SimpleLoop.condition1, body=SimpleLoop.body1 + ROOT 
get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // Before HloModuleDCE while.1 and while.2 should not have pass-thru elements. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 1)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + // After HloModuleDCE while.1 and while.2 should have pass-thru elements, + // after being modified to pass through unused tuple element {1}. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 0)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 1)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 0)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); +} + +// Tests that HloModuleDCE can remove a dead tuple element at while.1{0} and +// while.2{1}, between two dependent while loops. 
+TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body0 { + loop_var.1 = (s32[3]{0}, s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=1 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=0 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[3]{0}, s32[]) tuple(multiply, add) + } + SimpleLoop.condition0 { + loop_var.2 = (s32[3]{0}, s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + SimpleLoop.body1 { + loop_var.3 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0 + constant.3 = s32[] constant(1) + add.1 = s32[] add(get-tuple-element.4, constant.3) + get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1 + multiply.1 = s32[3]{0} multiply(get-tuple-element.5, get-tuple-element.5) + ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, multiply.1) + } + SimpleLoop.condition1 { + loop_var.4 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0 + constant.4 = s32[] constant(5) + ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4) + } + ENTRY SimpleLoop { + constant.5 = s32[] constant(0) + constant.6 = s32[3]{0} constant({0, 1, 2}) + tuple.2 = (s32[3]{0}, s32[]) tuple(constant.6, constant.5) + while.1 = (s32[3]{0}, s32[]) while(tuple.2), condition= + SimpleLoop.condition0, body=SimpleLoop.body0 + get-tuple-element.7 = s32[] get-tuple-element(while.1), index=1 + tuple.3 = (s32[], s32[3]{0}) tuple(get-tuple-element.7, constant.6) + while.2 = (s32[], s32[3]{0}) while(tuple.3), condition= + SimpleLoop.condition1, body=SimpleLoop.body1 + ROOT 
get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // Before HloModuleDCE while.1{0} and while.2{1} should not be pass-thru. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + // After HloModuleDCE while.1{0} and while.2{1} should be pass-thru elements. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 1)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 0)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 54c34ce116651608e6d91cdcba9c708ca3a5f75e..bf33640db16638803f4f8e6c66f35d6bb6e2c9fe 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h" +#include #include #include @@ -47,13 +48,16 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const { case ComputationKind::kConditionalFalse: repr += ":CONDITIONAL_FALSE"; break; + case ComputationKind::kCallFunction: + repr += ":CALL"; + break; } return repr; } /* static */ StatusOr> HloModuleGroupMetadata::Build(const std::vector& modules) { - auto metadata = absl::make_unique(modules); + auto metadata = MakeUnique(modules); TF_RETURN_IF_ERROR(metadata->Build()); return std::move(metadata); } @@ -83,6 +87,7 @@ Status HloModuleGroupMetadata::Build() { << "Peer instruction does not match the computation kind"; TF_RETURN_IF_ERROR( AddCompanion(tracked->instruction(), peer_tracked->instruction())); + tracked_instructions_comms_[tracked->instruction()].push_back(hlo); } // Add the parents of companion instructions (they must be all of the same @@ -107,6 +112,47 @@ Status HloModuleGroupMetadata::Build() { TF_RETURN_IF_ERROR(computation->Accept(visitor)); } } + TF_RETURN_IF_ERROR(VerifyCompanionSets()); + if (VLOG_IS_ON(4)) { + DumpCollectedStats(); + } + return Status::OK(); +} + +Status HloModuleGroupMetadata::VerifyCompanionSets() const { + for (const auto& companions : companion_sets_) { + // A companion set must be composed at most of an instruction per + // device/module. + std::unordered_set devices; + for (HloInstruction* instruction : *companions) { + // Go through all the communicating instructions (send, recv) of the given + // companion, and record their device. + auto it = tracked_instructions_comms_.find(instruction); + if (it == tracked_instructions_comms_.end()) { + // Companions can be added even if they have no communicating + // instructions, if they are parent of companions. 
+ continue; + } + std::unordered_set comm_devices; + for (HloInstruction* comm_instruction : it->second) { + auto device = GetInstructionDevice(*comm_instruction); + TF_RET_CHECK(device) << "Instruction " << comm_instruction->ToString() + << " does not have a device"; + comm_devices.insert(*device); + } + for (int64 device : comm_devices) { + if (!devices.insert(device).second) { + std::stringstream ss; + ss << "Companion set:" << std::endl; + for (HloInstruction* hlo : *companions) { + ss << " " << hlo->name() << std::endl; + } + ss << "has multiple instructions on the same device"; + return FailedPrecondition("%s", ss.str().c_str()); + } + } + } + } return Status::OK(); } @@ -194,6 +240,28 @@ int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const { LOG(FATAL) << "unknown module"; } +tensorflow::gtl::optional HloModuleGroupMetadata::GetInstructionDevice( + const HloInstruction& instruction) const { + // The module group metadata can be created in both "single module, multiple + // devices" and "multiple modules, no explicit devices" fashions. + // The API returns an optional even though the current implementation always + // returns a device, to account for cases where we cannot guess a device. + // In such cases the VerifyChannelInstructions() will return proper errors. 
+ tensorflow::gtl::optional device = + instruction.sharding_unique_device(); + if (!device) { + device = GetModuleId(instruction.parent()->parent()); + } + return device; +} + +int64 HloModuleGroupMetadata::GetDeviceModulesCount() const { + return std::count_if(modules_.begin(), modules_.end(), + [](const HloModule* module) { + return !module->config().is_host_module(); + }); +} + Status HloModuleGroupMetadata::RecordInstructions() { const auto visitor = [this](HloInstruction* hlo) -> Status { if (hlo->opcode() == HloOpcode::kWhile) { @@ -206,6 +274,9 @@ Status HloModuleGroupMetadata::RecordInstructions() { TrackedInstruction(hlo, ComputationKind::kConditionalTrue); tracked_instructions_[hlo->false_computation()] = TrackedInstruction(hlo, ComputationKind::kConditionalFalse); + } else if (hlo->opcode() == HloOpcode::kCall) { + tracked_instructions_[hlo->to_apply()] = + TrackedInstruction(hlo, ComputationKind::kCallFunction); } if (!IsChannelInstruction(hlo)) { return Status::OK(); @@ -252,20 +323,22 @@ Status HloModuleGroupMetadata::RecordInstructions() { TF_RETURN_IF_ERROR(computation->Accept(visitor)); } } + VLOG(2) << "Created " << channels_.size() << " channels"; return Status::OK(); } Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, HloInstruction* instruction2) { TF_RET_CHECK(instruction1->opcode() == HloOpcode::kWhile || - instruction1->opcode() == HloOpcode::kConditional); + instruction1->opcode() == HloOpcode::kConditional || + instruction1->opcode() == HloOpcode::kCall); VLOG(2) << "adding as companions:" << instruction1->ToString() << " and " << instruction2->ToString(); if (!ContainsKey(companion_set_index_, instruction1) && !ContainsKey(companion_set_index_, instruction2)) { companion_sets_.push_back( - absl::make_unique>()); + tensorflow::MakeUnique>()); auto companion_set = companion_sets_.back().get(); companion_set->insert(instruction1); companion_set->insert(instruction2); @@ -313,44 +386,46 @@ Status 
HloModuleGroupMetadata::VerifyChannelInstructions() { if (!ShapeUtil::Compatible(send_shape, recv_shape)) { return FailedPrecondition("send/recv shapes do not match"); } - const HloModule* send_module = channel.send->parent()->parent(); - const HloModule* send_done_module = channel.send_done->parent()->parent(); - if (send_module != send_done_module) { + auto send_device = GetInstructionDevice(*channel.send); + auto send_done_device = GetInstructionDevice(*channel.send_done); + if (!send_device) { + return FailedPrecondition("send instruction must have a device: %s", + channel.send->ToString().c_str()); + } + if (!send_done_device) { + return FailedPrecondition("send_done instruction must have a device: %s", + channel.send_done->ToString().c_str()); + } + if (*send_device != *send_done_device) { return FailedPrecondition( "send and send-done (channel=%lld) must be on the same device: %lld " "vs. %lld", - channel.id, GetModuleId(send_module), GetModuleId(send_done_module)); + channel.id, *send_device, *send_done_device); + } + auto recv_device = GetInstructionDevice(*channel.recv); + auto recv_done_device = GetInstructionDevice(*channel.recv_done); + if (!recv_done_device) { + return FailedPrecondition("recv_done instruction must have a device: %s", + channel.recv_done->ToString().c_str()); } - const HloModule* recv_module = channel.recv->parent()->parent(); - const HloModule* recv_done_module = channel.recv_done->parent()->parent(); - if (recv_module != recv_done_module) { + if (*recv_device != *recv_done_device) { return FailedPrecondition( "recv and recv-done (channel=%lld) must be on the same device: %lld " "vs. 
%lld", - channel.id, GetModuleId(recv_module), GetModuleId(recv_done_module)); + channel.id, *recv_device, *recv_done_device); } - if (send_module == recv_module) { + if (*send_device == *recv_device) { return FailedPrecondition( "send and recv (channel=%lld) must be on different devices: %lld", - channel.id, GetModuleId(send_module)); + channel.id, *send_device); } } - // Check if channel instructions are used only in allowed computations. - const auto allowed = [this](HloInstruction* hlo) { - HloComputation* computation = hlo->parent(); - const HloModule* module = computation->parent(); - if (module->entry_computation() == computation || - tracked_instructions_.count(computation) > 0) { - return true; - } - return false; - }; for (const Channel& channel : channels_) { - if (!allowed(channel.send) || !allowed(channel.send_done) || - !allowed(channel.recv) || !allowed(channel.recv_done)) { - return FailedPrecondition("channel is used in disallowed computation"); - } + TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send)); + TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send_done)); + TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv)); + TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv_done)); } // Check if the nest levels match for each channel. 
for (const Channel& channel : channels_) { @@ -368,4 +443,47 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { return Status::OK(); } +Status HloModuleGroupMetadata::CheckCommunicatingInstruction( + HloInstruction* instruction) const { + HloComputation* computation = instruction->parent(); + const HloModule* module = computation->parent(); + if (module->entry_computation() == computation || + tracked_instructions_.count(computation) > 0) { + return Status::OK(); + } + return FailedPrecondition("channel is used in disallowed computation"); +} + +void HloModuleGroupMetadata::DumpCollectedStats() const { + std::map, int64> communication_histogram; + for (auto& channel : channels_) { + auto from_device = GetInstructionDevice(*channel.send); + auto to_device = GetInstructionDevice(*channel.recv); + LOG(INFO) << "Channel " << channel.id << ": from_device=" << *from_device + << " to_device=" << *to_device << " send=" << channel.send->name() + << " send_done=" << channel.send_done->name() + << " recv=" << channel.recv->name() + << " recv_done=" << channel.recv_done->name(); + communication_histogram[std::pair(*from_device, + *to_device)] += 1; + } + for (auto& fromto_count : communication_histogram) { + LOG(INFO) << "From " << fromto_count.first.first << " to " + << fromto_count.first.second << ": " << fromto_count.second; + } + for (auto& companion_set : companion_sets_) { + LOG(INFO) << "Companion set:"; + for (HloInstruction* instruction : *companion_set) { + LOG(INFO) << " " << instruction->name(); + } + } + for (auto& instruction_comm : tracked_instructions_comms_) { + LOG(INFO) << "Communicating instruction " << instruction_comm.first->name(); + for (HloInstruction* instruction : instruction_comm.second) { + auto device = GetInstructionDevice(*instruction); + LOG(INFO) << " " << instruction->name() << " on device " << *device; + } + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h 
b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index c48a7ab0b59269474f7406ef24a249355528e085..ffde3a332dfc141ca928a44cfdf4686900e9f47b 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -60,6 +61,7 @@ class HloModuleGroupMetadata { kWhileBody, kConditionalTrue, kConditionalFalse, + kCallFunction, }; // Tracks the instruction mapped to a given computation, and the computation @@ -147,6 +149,15 @@ class HloModuleGroupMetadata { // the module in the module vector. int64 GetModuleId(const HloModule* module) const; + // Retrieves the device an instruction is assigned to. Either from the + // sharding information, or from the ordinal of the module the instruction + // is in. + tensorflow::gtl::optional GetInstructionDevice( + const HloInstruction& instruction) const; + + // Returns the number of modules for devices (excluding the host module). + int64 GetDeviceModulesCount() const; + // Returns the companion instructions for the given instruction. // // Precondition: IsCompanionWhile(instruction) is true. @@ -202,6 +213,15 @@ class HloModuleGroupMetadata { Status AddCompanion(HloInstruction* instruction1, HloInstruction* instruction2); + // Checks whether a communicating instruction is placed in a valid position + // within the graph. + Status CheckCommunicatingInstruction(HloInstruction* instruction) const; + + // Performs a consistency check on the companion sets built for the input + // modules. Check that a companion set does not include instructions from the + // same module/device. 
+ Status VerifyCompanionSets() const; + // Retrieves a pointer to the stored TrackedInstruction associated with a // tracked computation, or nullptr in case such computation is not tracked. const TrackedInstruction* GetTrackedInstruction( @@ -210,6 +230,9 @@ class HloModuleGroupMetadata { return it != tracked_instructions_.end() ? &it->second : nullptr; } + // Dump all the collected module group statistics to the logs. + void DumpCollectedStats() const; + // List of all companion instructions sets in the module. std::vector>> companion_sets_; @@ -221,6 +244,11 @@ class HloModuleGroupMetadata { tensorflow::gtl::FlatMap tracked_instructions_; + // Maps tracked instructions (kWhile, kConditional, kCall, ...) to the set of + // communicating instructions within the proper called computation(s). + tensorflow::gtl::FlatMap> + tracked_instructions_comms_; + // All channels in the module. std::vector channels_; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 289c96b0a7b90c5f8a122cd3fc327a5762099106..5a0d1e264eb5095ff53721416ebcf4842a063f97 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include #include +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -289,7 +290,7 @@ HloModuleGroupUtil::ComputeReachability( TF_RETURN_IF_ERROR( VisitTopologicalOrder(&visit_states, visit_function, root)); } - auto reachability = absl::make_unique(post_order); + auto reachability = MakeUnique(post_order); for (HloInstruction* hlo : post_order) { reachability->SetReachabilityToUnion(GlobalPredecessors(hlo), hlo); } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index ca763076a16af1150a8623fb7dbf22c46a5ca263..a35546f5f41b149d119ee141fd734da8bfd055b2 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -69,16 +69,19 @@ namespace xla { V(kCrossReplicaSum, "cross-replica-sum") \ V(kCustomCall, "custom-call") \ V(kDivide, "divide") \ + V(kDomain, "domain") \ V(kDot, "dot") \ V(kDynamicSlice, "dynamic-slice") \ V(kDynamicUpdateSlice, "dynamic-update-slice") \ V(kEq, "equal-to", kHloOpcodeIsComparison) \ V(kExp, "exponential") \ + V(kExpm1, "exponential-minus-one") \ V(kFft, "fft") \ V(kFloor, "floor") \ V(kFusion, "fusion", kHloOpcodeIsVariadic) \ V(kGather, "gather") \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ + V(kGenerateToken, "generate-token", kHloOpcodeIsVariadic) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ V(kHostCompute, "host-compute") \ @@ -87,6 +90,7 @@ namespace xla { V(kIsFinite, "is-finite") \ V(kLe, "less-than-or-equal-to", kHloOpcodeIsComparison) \ V(kLog, "log") \ + V(kLog1p, "log-plus-one") \ V(kAnd, "and") \ V(kNot, "not") \ V(kOr, "or") \ diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index 
cd2ce5c69f030c65b889d67e082a3677b8739ddb..774345124b4ad62e35d9423a23f1dbaa28e44d80 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -58,6 +58,7 @@ TEST(HloOpcodeTest, OpcodeProperties) { case HloOpcode::kConcatenate: case HloOpcode::kFusion: case HloOpcode::kMap: + case HloOpcode::kGenerateToken: case HloOpcode::kTuple: EXPECT_TRUE(HloOpcodeIsVariadic(opcode)); break; diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index e89d94bede6c437ca1131a1b1b0098390d58c0d9..dcd4725fe78e8b9b5d14437e964cb5aaf1664117 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -170,10 +169,10 @@ bool HloOrdering::UseIsBeforeValueDefinition( // is before the def if the instruction allows buffer sharing (in place // computation). 
if (use.instruction == value.defining_instruction() && - CanShareOperandBufferWithUser( + dataflow.CanShareOperandBufferWithUser( use.instruction->mutable_operand(use.operand_number), use.operand_index, value.defining_instruction(), - value.defining_index(), dataflow)) { + value.defining_index())) { VLOG(4) << " use is value def, and instruction can share use buffer"; return true; } diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 37a7fbad97cea2f34798efecc2489e57d1374f35..cfe5dace05ac03f1573f90b2ce664c94837837b4 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -22,10 +22,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -310,7 +310,7 @@ ENTRY while.v11 { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(module_str)); + ParseHloString(module_str)); DependencyHloOrdering ordering(module.get()); ordering.ToString(); // Shouldn't crash. 
} @@ -347,7 +347,7 @@ ENTRY root { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(module_str)); + ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN(auto dataflow, HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); DependencyHloOrdering ordering(module.get()); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc similarity index 86% rename from tensorflow/compiler/xla/tools/parser/hlo_parser.cc rename to tensorflow/compiler/xla/service/hlo_parser.cc index 156a06c596c3f1550213cb5ac5d11834a80b7181..4aa44062922e49b77529fd384a958a277e264d53 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -24,18 +26,17 @@ limitations under the License. 
#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { -namespace tools { namespace { -using tensorflow::StringPiece; -using tensorflow::gtl::optional; -using tensorflow::str_util::Join; -using tensorflow::str_util::Split; -using tensorflow::str_util::SplitAndParseAsInts; -using tensorflow::strings::Printf; -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; +using ::tensorflow::StringPiece; +using ::tensorflow::gtl::optional; +using ::tensorflow::str_util::Join; +using ::tensorflow::str_util::Split; +using ::tensorflow::str_util::SplitAndParseAsInts; +using ::tensorflow::strings::Printf; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; const double kF16max = 65504; @@ -56,6 +57,11 @@ class HloParser { // Returns the error information. string GetError() const { return Join(error_, "\n"); } + // Stand alone parsing utils for various aggregate data types. + StatusOr ParseShardingOnly(); + StatusOr ParseWindowOnly(); + StatusOr ParseConvolutionDimensionNumbersOnly(); + private: // ParseXXX returns false if an error occurred. bool ParseHloModule(); @@ -78,11 +84,15 @@ class HloParser { // Sets the sub-value of literal at the given index to the given value. The // literal's shape must have the default layout. 
- bool SetValueInLiteral(int64 value, int64 linear_index, Literal* literal); - bool SetValueInLiteral(double value, int64 linear_index, Literal* literal); - bool SetValueInLiteral(bool value, int64 linear_index, Literal* literal); + bool SetValueInLiteral(tensorflow::int64 value, + tensorflow::int64 linear_index, Literal* literal); + bool SetValueInLiteral(double value, tensorflow::int64 linear_index, + Literal* literal); + bool SetValueInLiteral(bool value, tensorflow::int64 linear_index, + Literal* literal); template - bool SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, + bool SetValueInLiteralHelper(ParsedElemT value, + tensorflow::int64 linear_index, Literal* literal); bool ParseOperands(std::vector* operands); @@ -94,9 +104,15 @@ class HloParser { // Describes the start, limit, and stride on every dimension of the operand // being sliced. struct SliceRanges { - std::vector starts; - std::vector limits; - std::vector strides; + std::vector starts; + std::vector limits; + std::vector strides; + }; + + // The data parsed for the kDomain instruction. + struct DomainData { + std::unique_ptr entry_metadata; + std::unique_ptr exit_metadata; }; // Types of attributes. @@ -117,6 +133,7 @@ class HloParser { kMetadata, kFusionKind, kDistribution, + kDomain, }; struct AttrConfig { @@ -164,21 +181,27 @@ class HloParser { bool ParseComputationName(HloComputation** value); // Parses a list of names and finds the corresponding hlo instructions. bool ParseInstructionNames(std::vector* instructions); - bool ParseWindow(Window* window); + // Pass expect_outer_curlies == true when parsing a Window in the context of a + // larger computation. Pass false when parsing a stand-alone Window string. 
+ bool ParseWindow(Window* window, bool expect_outer_curlies); bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums); bool ParsePaddingConfig(PaddingConfig* padding); bool ParseMetadata(OpMetadata* metadata); bool ParseSharding(OpSharding* sharding); bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); + // Parses the metadata behind a kDOmain instruction. + bool ParseDomain(DomainData* domain); + // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3. - bool ParseDxD(const string& name, std::vector* result); + bool ParseDxD(const string& name, std::vector* result); // Parses window's pad sub-attriute, e.g., pad=0_0x3x3. - bool ParseWindowPad(std::vector>* pad); + bool ParseWindowPad(std::vector>* pad); bool ParseSliceRanges(SliceRanges* result); bool ParseInt64List(const TokKind start, const TokKind end, - const TokKind delim, std::vector* result); + const TokKind delim, + std::vector* result); bool ParseParamListToShape(Shape* shape, LocTy* shape_loc); bool ParseParamList(); @@ -190,7 +213,7 @@ class HloParser { bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); - bool ParseInt64(int64* result); + bool ParseInt64(tensorflow::int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); bool ParseToken(TokKind kind, const string& msg); @@ -384,6 +407,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { } *entry_computation = computation; } + instruction_pool_.clear(); return AddComputation(name, computation, name_loc); } @@ -447,7 +471,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { - int64 parameter_number; + tensorflow::int64 parameter_number; if (!ParseToken(TokKind::kLparen, "expects '(' before parameter number") || !ParseInt64(&parameter_number) || @@ -481,10 +505,12 @@ bool
HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kFloor: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: @@ -561,11 +587,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kCrossReplicaSum: { + optional to_apply; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &to_apply}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateCrossReplicaSum(shape, operands)); + HloInstruction::CreateCrossReplicaSum(shape, operands, *to_apply)); break; } case HloOpcode::kReshape: { @@ -577,6 +606,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateReshape(shape, operands[0])); break; } + case HloOpcode::kGenerateToken: { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateGenerateToken(operands)); + break; + } case HloOpcode::kTuple: { if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; @@ -600,7 +637,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kRecv: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/0) || !ParseAttributes(attrs)) { @@ -611,7 +648,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kRecvDone: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -625,7 +662,7 @@ 
bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kSend: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -636,7 +673,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kSendDone: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -650,7 +687,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kGetTupleElement: { - optional index; + optional index; attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -708,7 +745,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kFft: { optional fft_type; - optional> fft_length; + optional> fft_length; attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type}; attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List, &fft_length}; @@ -721,7 +758,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kBroadcast: { - optional> broadcast_dimensions; + optional> broadcast_dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &broadcast_dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -733,7 +770,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kConcatenate: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands) || !ParseAttributes(attrs) || @@ -748,6 +785,9 @@ bool 
HloParser::ParseInstruction(HloComputation::Builder* builder, optional to_apply; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; + optional> dimensions; + attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List, + &dimensions}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } @@ -759,7 +799,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional reduce_computation; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &reduce_computation}; - optional> dimensions_to_reduce; + optional> dimensions_to_reduce; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions_to_reduce}; if (!ParseOperands(&operands, /*expected_size=*/2) || @@ -772,7 +812,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kReverse: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -816,7 +856,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kDynamicSlice: { - optional> dynamic_slice_sizes; + optional> dynamic_slice_sizes; attrs["dynamic_slice_sizes"] = { /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes}; if (!ParseOperands(&operands, /*expected_size=*/2) || @@ -840,7 +880,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kTranspose: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -854,7 +894,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kBatchNormTraining: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; 
attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/3) || @@ -870,7 +910,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kBatchNormInference: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/5) || @@ -887,7 +927,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kBatchNormGrad: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/5) || @@ -958,8 +998,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kReducePrecision: { - optional exponent_bits; - optional mantissa_bits; + optional exponent_bits; + optional mantissa_bits; attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64, &exponent_bits}; attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64, @@ -1004,7 +1044,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kHostCompute: { optional channel_name; - optional cost_estimate_ns; + optional cost_estimate_ns; attrs["channel_name"] = {/*required=*/true, AttrTy::kString, &channel_name}; attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64, @@ -1017,16 +1057,16 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kDot: { - optional> lhs_contracting_dims; + optional> lhs_contracting_dims; attrs["lhs_contracting_dims"] = { /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims}; - optional> rhs_contracting_dims; + optional> rhs_contracting_dims; 
attrs["rhs_contracting_dims"] = { /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims}; - optional> lhs_batch_dims; + optional> lhs_batch_dims; attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &lhs_batch_dims}; - optional> rhs_batch_dims; + optional> rhs_batch_dims; attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &rhs_batch_dims}; @@ -1058,20 +1098,20 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kGather: { - optional> output_window_dims; + optional> output_window_dims; attrs["output_window_dims"] = { /*required=*/true, AttrTy::kBracedInt64List, &output_window_dims}; - optional> elided_window_dims; + optional> elided_window_dims; attrs["elided_window_dims"] = { /*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims}; - optional> gather_dims_to_operand_dims; + optional> gather_dims_to_operand_dims; attrs["gather_dims_to_operand_dims"] = {/*required=*/true, AttrTy::kBracedInt64List, &gather_dims_to_operand_dims}; - optional index_vector_dim; + optional index_vector_dim; attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, &index_vector_dim}; - optional> window_bounds; + optional> window_bounds; attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List, &window_bounds}; @@ -1091,12 +1131,29 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, dim_numbers, *window_bounds)); break; } + case HloOpcode::kDomain: { + DomainData domain; + attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateDomain( + shape, operands[0], std::move(domain.entry_metadata), + std::move(domain.exit_metadata))); + break; + } case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); } - 
instruction->set_name(name); + instruction->SetAndSanitizeName(name); + if (instruction->name() != name) { + return Error(name_loc, + StrCat("illegal instruction name: ", name, + "; suggest renaming to: ", instruction->name())); + } // Add shared attributes like metadata to the instruction, if they were seen. if (sharding) { @@ -1116,7 +1173,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, instruction->set_metadata(*metadata); } if (backend_config) { - instruction->set_backend_config(std::move(*backend_config)); + instruction->set_raw_backend_config_string(std::move(*backend_config)); } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) @@ -1167,8 +1224,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, LocTy loc = lexer_.GetLoc(); bool maximal = false; bool replicated = false; - std::vector devices; - std::vector tile_assignment_dimensions; + std::vector devices; + std::vector tile_assignment_dimensions; Shape tile_shape; while (lexer_.GetKind() != TokKind::kRbrace) { switch (lexer_.GetKind()) { @@ -1195,7 +1252,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } do { - int64 dim; + tensorflow::int64 dim; if (!ParseInt64(&dim)) { return false; } @@ -1207,7 +1264,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return false; } do { - int64 device; + tensorflow::int64 device; if (!ParseInt64(&device)) { return false; } @@ -1266,10 +1323,10 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER); *sharding->mutable_tile_shape() = tile_shape; - for (int64 dim : tile_assignment_dimensions) { + for (tensorflow::int64 dim : tile_assignment_dimensions) { sharding->add_tile_assignment_dimensions(dim); } - for (int64 device : devices) { + for (tensorflow::int64 device : devices) { sharding->add_tile_assignment_devices(device); } } @@ -1278,6 +1335,34 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, 
return true; } +// domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ',' +// 'exit=' exit_sharding '}' +bool HloParser::ParseDomain(DomainData* domain) { + std::unordered_map attrs; + optional kind; + optional entry_sharding; + optional exit_sharding; + attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind}; + attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding}; + attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding}; + if (!ParseSubAttributes(attrs)) { + return false; + } + if (*kind == ShardingMetadata::KindName()) { + auto entry_sharding_ptr = MakeUnique( + HloSharding::FromProto(*entry_sharding).ValueOrDie()); + auto exit_sharding_ptr = MakeUnique( + HloSharding::FromProto(*exit_sharding).ValueOrDie()); + domain->entry_metadata = + MakeUnique(std::move(entry_sharding_ptr)); + domain->exit_metadata = + MakeUnique(std::move(exit_sharding_ptr)); + } else { + return TokenError(StrCat("unsupported domain kind: ", *kind)); + } + return true; +} + // '{' name+ '}' bool HloParser::ParseInstructionNames( std::vector* instructions) { @@ -1304,40 +1389,50 @@ bool HloParser::ParseInstructionNames( "expects '}' at the end of instruction name list"); } -bool HloParser::SetValueInLiteral(int64 value, int64 linear_index, +bool HloParser::SetValueInLiteral(tensorflow::int64 value, + tensorflow::int64 linear_index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case S8: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case S16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case S32: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case S64: - return SetValueInLiteralHelper(value, linear_index, literal); + return 
SetValueInLiteralHelper(value, linear_index, + literal); case U8: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U32: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U64: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); default: LOG(FATAL) << "unknown integral primitive type " << PrimitiveType_Name(shape.element_type()); } } -bool HloParser::SetValueInLiteral(double value, int64 linear_index, +bool HloParser::SetValueInLiteral(double value, tensorflow::int64 linear_index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case F16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, literal); case BF16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case F32: return SetValueInLiteralHelper(value, linear_index, literal); case F64: @@ -1348,7 +1443,7 @@ bool HloParser::SetValueInLiteral(double value, int64 linear_index, } } -bool HloParser::SetValueInLiteral(bool value, int64 linear_index, +bool HloParser::SetValueInLiteral(bool value, tensorflow::int64 linear_index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { @@ -1361,7 +1456,8 @@ bool HloParser::SetValueInLiteral(bool value, int64 linear_index, } template -bool HloParser::SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, +bool HloParser::SetValueInLiteralHelper(ParsedElemT value, + tensorflow::int64 linear_index, Literal* literal) { // Check that linear_index is in range. 
if (linear_index >= ShapeUtil::ElementsIn(literal->shape())) { @@ -1473,7 +1569,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, const Shape& shape) { - const int64 rank = ShapeUtil::Rank(shape); + const tensorflow::int64 rank = ShapeUtil::Rank(shape); if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { return false; } @@ -1481,8 +1577,8 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, // Create a literal with the given shape in default layout. *literal = Literal::CreateFromDimensions(shape.element_type(), AsInt64Slice(shape.dimensions())); - int64 nest_level = 0; - int64 linear_index = 0; + tensorflow::int64 nest_level = 0; + tensorflow::int64 linear_index = 0; // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}}, // when we are parsing the 2nd '{' (right before '1'), we are seeing a @@ -1490,14 +1586,14 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, // the first '}' (right after '3'), it means the sub-array ends, and the // sub-array is supposed to contain exactly 3 elements, so check if // elems_seen_per_dim[1] is 3. 
- std::vector elems_seen_per_dim(rank); + std::vector elems_seen_per_dim(rank); auto get_index_str = [&elems_seen_per_dim](int dim) -> string { - std::vector elems_seen_until_dim(elems_seen_per_dim.begin(), - elems_seen_per_dim.begin() + dim); + std::vector elems_seen_until_dim( + elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); return StrCat("[", Join(elems_seen_until_dim, ",", - [](string* out, const int64& num_elems) { - tensorflow::strings::StrAppend(out, num_elems - 1); + [](string* out, const tensorflow::int64& num_elems) { + StrAppend(out, num_elems - 1); }), "]"); }; @@ -1573,7 +1669,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, lexer_.Lex(); } else if (primitive_util::IsIntegralType(shape.element_type())) { LocTy loc = lexer_.GetLoc(); - int64 value; + tensorflow::int64 value; if (!ParseInt64(&value)) { return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); @@ -1613,29 +1709,29 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, switch (shape.element_type()) { case PRED: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S8: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S32: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S64: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U8: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U32: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U64: - return 
ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case F16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case F32: return ParseSparseLiteralHelper(literal, shape); case BF16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case F64: return ParseSparseLiteralHelper(literal, shape); default: @@ -1648,9 +1744,9 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, template bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, const Shape& shape) { - std::vector index; + std::vector index; - int64 rank = ShapeUtil::Rank(shape); + tensorflow::int64 rank = ShapeUtil::Rank(shape); *literal = MakeUnique(shape); @@ -1668,7 +1764,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, LocTy index_loc = lexer_.GetLoc(); index.clear(); if (lexer_.GetKind() == TokKind::kInt) { - int64 single_index = lexer_.GetInt64Val(); + tensorflow::int64 single_index = lexer_.GetInt64Val(); lexer_.Lex(); if (rank != 1) { return Error( @@ -1701,7 +1797,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, value = static_cast(lexer_.GetKind() == TokKind::kw_true); lexer_.Lex(); } else if (primitive_util::IsIntegralType(shape.element_type())) { - int64 value_s64; + tensorflow::int64 value_s64; if (!ParseInt64(&value_s64)) { return Error(value_loc, StrCat("expects integer for primitive type: ", @@ -1874,23 +1970,24 @@ bool HloParser::ParseAttributeHelper( LocTy attr_loc = lexer_.GetLoc(); switch (attr_type) { case AttrTy::kInt64: { - int64 result; + tensorflow::int64 result; if (!ParseInt64(&result)) { return false; } - static_cast*>(attr_out_ptr)->emplace(result); + static_cast*>(attr_out_ptr) + ->emplace(result); return true; } case AttrTy::kInt32: { - int64 result; + tensorflow::int64 result; if (!ParseInt64(&result)) { return false; } - if (result != static_cast(result)) { + if 
(result != static_cast(result)) { return Error(attr_loc, "value out of range for int32"); } - static_cast*>(attr_out_ptr) - ->emplace(static_cast(result)); + static_cast*>(attr_out_ptr) + ->emplace(static_cast(result)); return true; } case AttrTy::kFloat: { @@ -1924,7 +2021,7 @@ bool HloParser::ParseAttributeHelper( } case AttrTy::kWindow: { Window result; - if (!ParseWindow(&result)) { + if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) { return false; } static_cast*>(attr_out_ptr)->emplace(result); @@ -1966,12 +2063,12 @@ bool HloParser::ParseAttributeHelper( return true; } case AttrTy::kBracedInt64List: { - std::vector result; + std::vector result; if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, &result)) { return false; } - static_cast>*>(attr_out_ptr) + static_cast>*>(attr_out_ptr) ->emplace(result); return true; } @@ -2016,6 +2113,9 @@ bool HloParser::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kDomain: { + return ParseDomain(static_cast(attr_out_ptr)); + } } }(); if (!success) { @@ -2042,9 +2142,10 @@ bool HloParser::ParseComputationName(HloComputation** value) { // ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}' // The subattributes can appear in any order. 'size=' is required, others are // optional. -bool HloParser::ParseWindow(Window* window) { +bool HloParser::ParseWindow(Window* window, bool expect_outer_curlies) { LocTy loc = lexer_.GetLoc(); - if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { + if (expect_outer_curlies && + !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { return false; } @@ -2054,7 +2155,9 @@ bool HloParser::ParseWindow(Window* window) { std::vector lhs_dilate; std::vector rhs_dilate; std::vector rhs_reversal; - while (lexer_.GetKind() != TokKind::kRbrace) { + const auto end_token = + expect_outer_curlies ? 
TokKind::kRbrace : TokKind::kEof; + while (lexer_.GetKind() != end_token) { LocTy attr_loc = lexer_.GetLoc(); string field_name; if (!ParseAttributeName(&field_name)) { @@ -2118,7 +2221,8 @@ bool HloParser::ParseWindow(Window* window) { window->mutable_dimensions(i)->set_window_reversal( rhs_reversal.empty() ? false : (rhs_reversal[i] == 1)); } - return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); + return !expect_outer_curlies || + ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); } // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString. @@ -2142,7 +2246,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( << str; } - const int64 rank = lhs_rhs_out[0].length(); + const tensorflow::int64 rank = lhs_rhs_out[0].length(); if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) { return TokenError( "convolution lhs, rhs, and output must have the same rank"); @@ -2256,7 +2360,7 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) { return false; } - std::vector> ranges; + std::vector> ranges; if (lexer_.GetKind() == TokKind::kRbrace) { // empty return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); @@ -2290,7 +2394,7 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { // ::= int64_val (delim int64_val)* bool HloParser::ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, - std::vector* result) { + std::vector* result) { if (!ParseToken(start, StrCat("expects an int64 list starting with ", TokKindToString(start)))) { return false; @@ -2299,7 +2403,7 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end, // empty } else { do { - int64 i; + tensorflow::int64 i; if (!ParseInt64(&i)) { return false; } @@ -2416,7 +2520,8 @@ bool HloParser::ParseString(string* result) { return true; } -bool HloParser::ParseDxD(const string& name, std::vector* result) { +bool 
HloParser::ParseDxD(const string& name, + std::vector* result) { LocTy loc = lexer_.GetLoc(); if (!result->empty()) { return Error(loc, @@ -2424,7 +2529,7 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { } // 1D if (lexer_.GetKind() == TokKind::kInt) { - int64 number; + tensorflow::int64 number; if (!ParseInt64(&number)) { return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str())); } @@ -2444,7 +2549,8 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { return TokenError("expects token type kInt or kDxD"); } -bool HloParser::ParseWindowPad(std::vector>* pad) { +bool HloParser::ParseWindowPad( + std::vector>* pad) { LocTy loc = lexer_.GetLoc(); if (!pad->empty()) { return Error(loc, "sub-attribute 'pad=' already exists"); @@ -2455,7 +2561,7 @@ bool HloParser::ParseWindowPad(std::vector>* pad) { string str = lexer_.GetStrVal(); std::vector padding_str = Split(str, 'x'); for (int i = 0; i < padding_str.size(); i++) { - std::vector low_high; + std::vector low_high; if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) || low_high.size() != 2) { return Error(loc, @@ -2479,7 +2585,7 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { string str = lexer_.GetStrVal(); std::vector padding_str = Split(str, 'x'); for (const auto& padding_dim_str : padding_str) { - std::vector padding_dim; + std::vector padding_dim; if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) || (padding_dim.size() != 2 && padding_dim.size() != 3)) { return Error(loc, @@ -2501,7 +2607,7 @@ bool HloParser::ParseMetadata(OpMetadata* metadata) { optional op_type; optional op_name; optional source_file; - optional source_line; + optional source_line; attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type}; attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name}; attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file}; @@ -2588,7 +2694,7 @@ bool 
HloParser::ParseRandomDistribution(RandomDistribution* result) { return true; } -bool HloParser::ParseInt64(int64* result) { +bool HloParser::ParseInt64(tensorflow::int64* result) { VLOG(1) << "ParseInt64"; if (lexer_.GetKind() != TokKind::kInt) { return TokenError("expects integer"); @@ -2671,10 +2777,48 @@ bool HloParser::AddComputation(const string& name, HloComputation* computation, return true; } +StatusOr HloParser::ParseShardingOnly() { + lexer_.Lex(); + OpSharding op_sharding; + if (!ParseSharding(&op_sharding)) { + return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after sharding"); + } + return HloSharding::FromProto(op_sharding); +} + +StatusOr HloParser::ParseWindowOnly() { + lexer_.Lex(); + Window window; + if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) { + return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after window"); + } + return window; +} + +StatusOr +HloParser::ParseConvolutionDimensionNumbersOnly() { + lexer_.Lex(); + ConvolutionDimensionNumbers dnums; + if (!ParseConvolutionDimensionNumbers(&dnums)) { + return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument( + "Syntax error:\nExtra content after convolution dnums"); + } + return dnums; +} + } // namespace -StatusOr> Parse(StringPiece str, - const HloModuleConfig& config) { +StatusOr> ParseHloString( + tensorflow::StringPiece str, const HloModuleConfig& config) { HloParser parser(str, config); if (!parser.Run()) { return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str()); @@ -2682,10 +2826,29 @@ StatusOr> Parse(StringPiece str, return parser.ConsumeHloModule(); } -StatusOr> Parse(StringPiece str) { +StatusOr> ParseHloString( + tensorflow::StringPiece str) { + 
HloModuleConfig config; + return ParseHloString(str, config); +} + +StatusOr ParseSharding(tensorflow::StringPiece str) { HloModuleConfig config; - return Parse(str, config); + HloParser parser(str, config); + return parser.ParseShardingOnly(); +} + +StatusOr ParseWindow(tensorflow::StringPiece str) { + HloModuleConfig config; + HloParser parser(str, config); + return parser.ParseWindowOnly(); +} + +StatusOr ParseConvolutionDimensionNumbers( + tensorflow::StringPiece str) { + HloModuleConfig config; + HloParser parser(str, config); + return parser.ParseConvolutionDimensionNumbersOnly(); } -} // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h similarity index 52% rename from tensorflow/compiler/xla/tools/parser/hlo_parser.h rename to tensorflow/compiler/xla/service/hlo_parser.h index 2f97a2b9b19d0cdb64a2869913da62c55e14c1d5..3f3a51215e34bbdd667f1cb20d0ae968e0ce5efd 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -13,30 +13,47 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_lexer.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { -namespace tools { + +// For details about the syntax accepted by this parser, see +// g3doc/hlo_parser.md. // The api of the hlo parser. Given a string in the HloModule::ToString() // format, parses the string and creates a HloModule with the given config. -StatusOr> Parse(tensorflow::StringPiece str, - const HloModuleConfig& config); +StatusOr> ParseHloString( + tensorflow::StringPiece str, const HloModuleConfig& config); // The api of the hlo parser. Given a string in the HloModule::ToString() // format, parses the string and creates a HloModule with default config. -StatusOr> Parse(tensorflow::StringPiece str); +StatusOr> ParseHloString( + tensorflow::StringPiece str); + +// Parses the result of HloSharding::ToString(), e.g. "{replicated}". +StatusOr ParseSharding(tensorflow::StringPiece str); + +// Parses the result of window_util::ToString(const Window&). +StatusOr ParseWindow(tensorflow::StringPiece str); + +// Parses the result of ConvolutionDimensionNumbersToString(), e.g. +// "b0f_0io->b0f". +StatusOr ParseConvolutionDimensionNumbers( + tensorflow::StringPiece str); + +// ParseHloString sharding from str. str is supposed to contain the body of the +// sharding, i.e. 
just the rhs of the "sharding={...}" attribute string. +StatusOr ParseSharding(tensorflow::StringPiece str); -} // namespace tools } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc similarity index 88% rename from tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc rename to tensorflow/compiler/xla/service/hlo_parser_test.cc index e100d8cda14eabbec3942bf442aa99cc04daada4..1c5a47c8755e7df37bd9a77be7ca9b9505b35e71 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -13,19 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace xla { -namespace tools { + namespace { -using tensorflow::StringPiece; +using ::tensorflow::StringPiece; struct TestData { string test_name; @@ -233,6 +234,17 @@ ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f3 ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}} } +)" +}, +{ +"DomainParsing", +R"(HloModule DomainParsing_module + +ENTRY %DomainParsing (v1: f32[]) -> f32[] { + %v1 = f32[] parameter(0) + ROOT %dom = f32[] domain(f32[] %v1), domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} 
+} + )" }, // int32 result = 0; @@ -753,7 +765,7 @@ add_F32.v3 { ENTRY MapBinaryAdder.v3 { param0 = f32[4]{0} parameter(0) param1 = f32[4]{0} parameter(1) - ROOT map = f32[4]{0} map(param0, param1), to_apply=add_F32.v3 + ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=add_F32.v3 } )" @@ -886,6 +898,24 @@ ENTRY Gather { ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} } +)" +}, +// cross-replica-sum +{ +"CrossReplicaSum", +R"(HloModule CRS + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY CRS { + input = f32[8]{0} parameter(0) + ROOT crs = f32[8]{0} cross-replica-sum(input), to_apply=add +} + )" }, }); @@ -900,12 +930,12 @@ class HloParserTest : public ::testing::Test, << "'" << s << "' does not contain '" << expected << "'"; } - // Expects "ToString(Parse(string)) == string", that is, parses the string, - // asserts that it succeeded, stringifies the parsed module, and checks that - // the it equals the original string. + // Expects "ToString(ParseHloString(string)) == string", that is, parses the + // string, asserts that it succeeded, stringifies the parsed module, and + // checks that the it equals the original string. 
void ExpectEqual() { const string& original = GetParam().module_string; - auto result = Parse(original); + auto result = ParseHloString(original); TF_ASSERT_OK(result.status()); EXPECT_EQ(original, result.ValueOrDie()->ToString( HloPrintOptions().set_print_large_constants(true))); @@ -916,7 +946,7 @@ class HloParserShortTest : public HloParserTest { protected: void ExpectEqualShort() { const string& original = GetParam().module_string; - auto result = Parse(original); + auto result = ParseHloString(original); TF_ASSERT_OK(result.status()); EXPECT_EQ(original, result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable())); @@ -937,14 +967,14 @@ INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest, TEST_F(HloParserTest, Empty) { const string original = ""; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, Garbage) { const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, WrongOpcode) { @@ -957,8 +987,8 @@ ENTRY %blabla (x: f32[], y: f32[]) -> f32[] { } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, WrongShape) { @@ -969,8 +999,8 @@ ENTRY %blabla (x: g32[]) -> g32[] { } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, WrongOperandsSize) { @@ -982,8 +1012,8 @@ ENTRY %blabla (x: f32[]) -> pred[] { } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), 
result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, OperandNotFound) { @@ -993,8 +1023,8 @@ ENTRY %blabla (x: f32[]) -> pred[] { %eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y) } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, MoreConstants) { @@ -1008,7 +1038,7 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); TF_EXPECT_OK(result.status()); // Constant instructions have no name. The string will be parsed successfully // but the constant names will not be exactly the same. @@ -1019,12 +1049,12 @@ TEST_F(HloParserTest, ConfigurationField) { ENTRY %configuration_test() -> s32[] { %constant = s32[] constant(42), backend_config="foo bar" })"; - auto result = Parse(original); + auto result = ParseHloString(original); TF_ASSERT_OK(result.status()); EXPECT_EQ("foo bar", result.ValueOrDie() ->entry_computation() ->root_instruction() - ->backend_config()); + ->raw_backend_config_string()); } TEST_F(HloParserTest, LiteralDimensionsMismatch_1) { @@ -1035,8 +1065,8 @@ ENTRY %some_2 () -> f32[2] { } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects nested array in rank 1, but sees larger"); } @@ -1049,8 +1079,8 @@ ENTRY %some_2x3 () -> f32[2,3] { } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects nested array in rank 2, but sees 1"); } @@ -1063,8 +1093,8 @@ ENTRY %some_2x3x2 () -> f32[2,3,2] { } )"; - auto 
result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects 3 elements in the [0]th element"); } @@ -1078,8 +1108,8 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] { } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "is out of range for literal's primitive type F16"); } @@ -1092,7 +1122,7 @@ ENTRY %ConstantWithExp.v4 () -> f32[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); TF_EXPECT_OK(result.status()); // The string will be parsed successfully but the output strings are not // exactly the same, because "3e2" is parsed into value 300 and will be @@ -1110,7 +1140,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 } )"; - TF_EXPECT_OK(Parse(original).status()); + TF_EXPECT_OK(ParseHloString(original).status()); } TEST_F(HloParserTest, InvalidDimLabels) { @@ -1126,17 +1156,18 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 )"; + ExpectHasSubstr(ParseHloString(tensorflow::strings::StrCat( + prefix, ",dim_labels=00_01_10", suffix)) + .status() + .error_message(), + "expects dim labels pattern"); + ExpectHasSubstr( - Parse(tensorflow::strings::StrCat(prefix, ",dim_labels=00_01_10", suffix)) + ParseHloString(tensorflow::strings::StrCat( + prefix, ",dim_labels=010_1100->010", suffix)) .status() .error_message(), - "expects dim labels pattern"); - - ExpectHasSubstr(Parse(tensorflow::strings::StrCat( - prefix, ",dim_labels=010_1100->010", suffix)) - .status() - .error_message(), - "must have the same rank"); + "must have the same rank"); } TEST_F(HloParserTest, UnexpectedAttribute) { @@ -1151,7 +1182,7 @@ ENTRY 
%TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "unexpected attribute \"calls\""); } @@ -1167,7 +1198,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "attribute channel_id is expected but not seen"); } @@ -1183,7 +1214,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "'done' is not defined"); } @@ -1196,7 +1227,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { } )"; - TF_EXPECT_OK(Parse(original).status()); + TF_EXPECT_OK(ParseHloString(original).status()); } TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) { @@ -1210,7 +1241,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "expects padding_low and padding_high separated by '_'"); } @@ -1222,7 +1253,7 @@ ENTRY %test_comma.v4 () -> f32[] { } )"; - TF_EXPECT_OK(Parse(original).status()); + TF_EXPECT_OK(ParseHloString(original).status()); } TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) { @@ -1232,7 +1263,7 @@ ENTRY %CustomCall () -> f32[1] { %constant = f32[1]{0} constant({12345}) ROOT %foo = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar" })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "Shape of computation CustomCall, f32[1], is not compatible " "with that of its root instruction foo, f32[1,2,3]"); } @@ -1251,7 +1282,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { ROOT reduce = 
f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 })"; - auto module = Parse(original); + auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); auto program_layout = module.ValueOrDie()->host_entry_computation_layout(); ASSERT_EQ(program_layout.parameter_count(), 1); @@ -1274,7 +1305,7 @@ c1 { c2 { const2 = f32[1]{0} constant({67890}) })"; - auto module = Parse(original); + auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); EXPECT_EQ(module.ValueOrDie()->entry_computation()->name(), "c2"); } @@ -1285,7 +1316,7 @@ ENTRY consts { first = f32[1]{0} constant({12345}) last = f32[1]{0} constant({67890}) })"; - auto module = Parse(original); + auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); EXPECT_EQ( module.ValueOrDie()->entry_computation()->root_instruction()->name(), @@ -1300,7 +1331,7 @@ ENTRY c1 { ENTRY c2 { const2 = f32[1]{0} constant({67890}) })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "expects only one ENTRY"); } @@ -1310,25 +1341,10 @@ ENTRY consts { ROOT const1 = f32[1]{0} constant({12345}) ROOT const2 = f32[1]{0} constant({12345}) })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "one computation should have only one ROOT"); } -TEST_F(HloParserTest, InstructionExists) { - const string original = R"(HloModule comp_exists -c1 { - instr = f32[1]{0} constant({12345}) -} -c2 { - instr = f32[1]{0} constant({67890}) -})"; - - ExpectHasSubstr(Parse(original).status().error_message(), - R"(was parsing 3:3: error: instruction previously defined here - instr = f32[1]{0} constant({12345}) - ^)"); -} - TEST_F(HloParserTest, ComputationExists) { const string original = R"(HloModule comp_exists comp { @@ -1337,12 +1353,52 @@ comp { comp { const2 = f32[1]{0} constant({67890}) })"; - 
ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), R"(was parsing 2:1: error: computation previously defined here comp { ^)"); } +TEST_F(HloParserTest, CrossComputationLookup) { + const string original = R"(HloModule cross_computation_lookup: +tcalla (a: (s32[], s32[])) -> (s32[], s32[]) { + ROOT aparam = (s32[], s32[]) parameter(0) +} + +tcallb (b: (s32[], s32[])) -> s32[] { + rparam = (s32[], s32[]) parameter(0) + ROOT gte0 = s32[] get-tuple-element(aparam), index=0 +} + +ENTRY entry { + param = (s32[], s32[]) parameter(0) + call0 = (s32[], s32[]) call(param), to_apply=tcalla + ROOT call1 = s32[] call(param), to_apply=tcallb +})"; + ExpectHasSubstr( + ParseHloString(original).status().error_message(), + "was parsing 8:39: error: instruction does not exist: aparam"); +} + +TEST_F(HloParserTest, ParseSharding) { + const string original = "{maximal device=42}"; + TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original)); + EXPECT_EQ(sharding.ToString(), original); +} + +TEST_F(HloParserTest, ParseWindow) { + Window original = window_util::MakeWindow({1, 2, 3}); + TF_ASSERT_OK_AND_ASSIGN(Window parsed, + ParseWindow(window_util::ToString(original))) + EXPECT_EQ(window_util::ToString(original), window_util::ToString(parsed)); +} + +TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) { + const string original = "b0f_0io->b0f"; + TF_ASSERT_OK_AND_ASSIGN(ConvolutionDimensionNumbers dnums, + ParseConvolutionDimensionNumbers(original)); + EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums)); +} + } // namespace -} // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 8e167633bb13476301fa0c4afa0b123c9b47e40d..4738e46f8aeb96a4c25d04b3246bd21f644fe3ea 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ 
b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -33,17 +33,27 @@ bool HloReachabilityMap::SetReachabilityToUnion( const HloInstruction* instruction) { BitVector& bit_vector = GetBitVector(instruction); tmp_bit_vector_ = bit_vector; + SetReachabilityToUnionHelper(inputs, instruction, &bit_vector); + return bit_vector != tmp_bit_vector_; +} +void HloReachabilityMap::FastSetReachabilityToUnion( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction) { + SetReachabilityToUnionHelper(inputs, instruction, &GetBitVector(instruction)); +} + +void HloReachabilityMap::SetReachabilityToUnionHelper( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction, BitVector* bit_vector) { // If instruction is part of inputs, don't reset the bit_vector. if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) { - bit_vector.SetToZero(); + bit_vector->SetToZero(); } - bit_vector.Set(GetIndex(instruction)); + bit_vector->Set(GetIndex(instruction)); for (const HloInstruction* input : inputs) { - bit_vector.OrWith(GetBitVector(input)); + bit_vector->OrWith(GetBitVector(input)); } - - return bit_vector != tmp_bit_vector_; } void HloReachabilityMap::SetReachable(const HloInstruction* a, diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 553ec11f6f9a2997ab7113f9b8241e04c7fe20d5..69bb2b3cee6dafe058c45b4e74e93401bea2cfc9 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -57,6 +57,11 @@ class HloReachabilityMap { tensorflow::gtl::ArraySlice inputs, const HloInstruction* instruction); + // As above, but faster because it does not check if the reachability changed. + void FastSetReachabilityToUnion( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction); + // Sets entry so that IsReachable(a, b) will return true // // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! 
It sets the adjacency @@ -133,6 +138,11 @@ class HloReachabilityMap { return bit_vectors_[GetIndex(instruction)]; } + // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion. + void SetReachabilityToUnionHelper( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction, BitVector* bit_vector); + // Return the index of the given instruction. The value is used to index into // the vector of BitVectors and the BitVectors themselves. int GetIndex(const HloInstruction* instruction) const { diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index b171d41a31ed23f0886e7363289ea56c92216572..bd1d9935bd37ff71064a1f8f431b2ddf9c7c789d 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -72,6 +71,20 @@ bool IsRematerializable(const HloInstruction* instruction) { } } +// Checks whether an instruction can be rematerialized, by looking up the +// cache before, and eventually calling the IsRematerializable() API. +bool CanBeRematerialized( + const HloInstruction* instruction, + tensorflow::gtl::FlatMap* remat_able) { + auto it = remat_able->find(instruction); + if (it != remat_able->end()) { + return it->second; + } + bool rematerializable = IsRematerializable(instruction); + (*remat_able)[instruction] = rematerializable; + return rematerializable; +} + // Type holding a unique identifier for each Buffer object. 
using BufferId = int64; using BufferIdList = tensorflow::gtl::InlinedVector; @@ -274,9 +287,8 @@ ItemList GetUsers(const InstructionList& instruction_list, for (const BufferAlias& buffer_alias : points_to_analysis.GetBufferAliases(*logical_buffer)) { for (const HloInstruction* user : buffer_alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(buffer_alias.instruction(), - buffer_alias.index(), user, - points_to_analysis)) { + if (points_to_analysis.DoesNotUseOperandBuffer( + buffer_alias.instruction(), buffer_alias.index(), user)) { // The alias may be an operand of 'user', but the LogicalBuffer cannot // possibly be used by the instruction so ignore 'user'. This is the // case, for example, for the tuple element buffers in a GetTupleElement @@ -845,9 +857,10 @@ int64 RematerializationCost(const HloInstruction* instruction, // candidate which reduce memory use at the program point of the current // instruction as indicated by memory_tracker. nullptr is returned if no // candidate can be found. -Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker, - const InstructionList& instruction_list, - int64 memory_limit_bytes) { +Item* PickRematerializationCandidate( + const MemoryUsageTracker& memory_tracker, + const InstructionList& instruction_list, int64 memory_limit_bytes, + tensorflow::gtl::FlatMap* remat_able) { Item* best_item = nullptr; int64 best_cost = 0; @@ -871,8 +884,7 @@ Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker, << " is excluded from rematerialization"; continue; } - - if (!IsRematerializable(candidate)) { + if (!CanBeRematerialized(candidate, remat_able)) { VLOG(5) << "candidate " << candidate->name() << " not viable: is not rematerializable"; continue; @@ -976,6 +988,9 @@ StatusOr HloRematerialization::RematerializeComputation( // blacklist. tensorflow::gtl::FlatSet remat_move_instructions; + // The map from instructions to their rematerializable status. 
+ tensorflow::gtl::FlatMap remat_able; + // The peak memory of the computation at any point in the instruction // sequence. int64 peak_memory = memory_tracker.memory_usage(); @@ -1013,7 +1028,7 @@ StatusOr HloRematerialization::RematerializeComputation( << ", limit is " << HumanReadableNumBytes(memory_limit_bytes); Item* best_item = PickRematerializationCandidate( - memory_tracker, instruction_list, memory_limit_bytes); + memory_tracker, instruction_list, memory_limit_bytes, &remat_able); if (best_item == nullptr) { VLOG(3) << "Unable to find rematerialization candidate at program " diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 83de54f3fa56ee660b79d8c366dbc0b52f9fde87..e81334d5a84268a129cd4e90091e97dc23243226 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -40,7 +41,8 @@ class HloRematerializationTest : public HloTestBase { // Creates and returns a computation which can benefit from // rematerialization. 
The computation looks like: // - // F32[] %param = {...} + // F32[1] %param = {...} + // F32[] %reshape = reshape(F32[], param) // F32[1024] %bcast = broadcast(%param) // F32[1024] %negate = negate(%bcast) // F32[2048] %concat_1 = concat({%negate, %negate}) @@ -57,9 +59,11 @@ class HloRematerializationTest : public HloTestBase { const string& suffix = "") { auto builder = HloComputation::Builder(TestName() + suffix); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param")); + HloInstruction::CreateParameter(0, vec1_shape_, "param")); + auto reshape = builder.AddInstruction( + HloInstruction::CreateReshape(scalar_shape_, param)); auto bcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); + HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {})); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, bcast)); auto concat_1 = builder.AddInstruction(HloInstruction::CreateConcatenate( @@ -100,9 +104,11 @@ class HloRematerializationTest : public HloTestBase { const string& suffix = "") { auto builder = HloComputation::Builder(TestName() + suffix); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param")); + HloInstruction::CreateParameter(0, vec1_shape_, "param")); + auto reshape = builder.AddInstruction( + HloInstruction::CreateReshape(scalar_shape_, param)); auto bcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); + HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {})); auto slice_1 = builder.AddInstruction( HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0}, /*limit_indices=*/{1}, @@ -135,6 +141,15 @@ class HloRematerializationTest : public HloTestBase { return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); } + StatusOr RunHloRematerialization( + int64 memory_limit_bytes, HloModule* module, + 
SequentialHloOrdering::HloModuleSequence* sequence) { + TF_EXPECT_OK(verifier().Run(module).status()); + return HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler, + sequence); + } + // Various shapes used in the canned computations. const Shape scalar_shape_ = ShapeUtil::MakeShape(xla::F32, {}); const Shape vec1_shape_ = ShapeUtil::MakeShape(xla::F32, {1}); @@ -158,11 +173,9 @@ TEST_F(HloRematerializationTest, SingleComputation) { SequentialHloOrdering::HloModuleSequence sequence; // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/14 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/14 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); // Root should not have changed. @@ -188,18 +201,16 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { HloComputation* computation = module->AddEntryComputation(MakeRematerializableComputation()); - EXPECT_EQ(computation->instruction_count(), 7); + EXPECT_EQ(computation->instruction_count(), 8); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/20 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/20 * 1024, + module.get(), &sequence)); // No instructions should have been materialized. 
EXPECT_FALSE(changed); - EXPECT_EQ(computation->instruction_count(), 7); + EXPECT_EQ(computation->instruction_count(), 8); } // Test rematerialization of a computation which calls another computation via a @@ -225,23 +236,21 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { module->AddEntryComputation(MakeRematerializableWhileComputation( while_cond, /*while_body=*/body_computation)); - EXPECT_EQ(entry_computation->instruction_count(), 6); - EXPECT_EQ(body_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 7); + EXPECT_EQ(body_computation->instruction_count(), 8); // The body computation uses 16KB and the entry computation uses 2KB at the // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/17 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/17 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. 
- EXPECT_EQ(entry_computation->instruction_count(), 7); - EXPECT_EQ(body_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 8); + EXPECT_EQ(body_computation->instruction_count(), 8); } // Test rematerialization of a computation which calls another computation via a @@ -264,20 +273,18 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { module->AddEntryComputation(MakeRematerializableWhileComputation( while_cond, /*while_body=*/body_computation)); - EXPECT_EQ(entry_computation->instruction_count(), 6); - EXPECT_EQ(body_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 7); + EXPECT_EQ(body_computation->instruction_count(), 8); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/15 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/15 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); - // Both computations should have a rematerialized instruction added. - EXPECT_EQ(entry_computation->instruction_count(), 7); - EXPECT_EQ(body_computation->instruction_count(), 8); + // Both computations should have rematerialized instructions added. + EXPECT_EQ(entry_computation->instruction_count(), 9); + EXPECT_EQ(body_computation->instruction_count(), 9); } // Test rematerialization of a doubly nested computation. 
All computations @@ -303,24 +310,22 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { module->AddEntryComputation(MakeRematerializableWhileComputation( while_cond, /*while_body=*/middle_computation)); - EXPECT_EQ(entry_computation->instruction_count(), 6); - EXPECT_EQ(middle_computation->instruction_count(), 6); - EXPECT_EQ(inner_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 7); + EXPECT_EQ(middle_computation->instruction_count(), 7); + EXPECT_EQ(inner_computation->instruction_count(), 8); // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/13 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/13 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); - // All computations should have a rematerialized instruction added. - EXPECT_EQ(entry_computation->instruction_count(), 7); - EXPECT_EQ(middle_computation->instruction_count(), 7); - EXPECT_EQ(inner_computation->instruction_count(), 8); + // All computations should have rematerialized instructions added. + EXPECT_EQ(entry_computation->instruction_count(), 9); + EXPECT_EQ(middle_computation->instruction_count(), 9); + EXPECT_EQ(inner_computation->instruction_count(), 9); } TEST_F(HloRematerializationTest, RngNotRematerialized) { @@ -382,10 +387,9 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { // parameter and output) and 20KB (peak memory possible with // rematerialization). 
TF_ASSERT_OK_AND_ASSIGN( - bool changed, HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, + bool changed, RunHloRematerialization( /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), DefaultMemoryScheduler, &sequence)); + module.get(), &sequence)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -476,11 +480,9 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/22 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -573,11 +575,9 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/22 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, + module.get(), &sequence)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. 
if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 48da1a505c9bea72378aaba7824548cca0eef447..e1f9d8efd4974055947438c8a2e15cb77d1b5c75 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -19,13 +19,12 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -37,7 +36,7 @@ HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string, const DebugOptions& debug_options) { HloModuleConfig config; config.set_debug_options(debug_options); - return tools::Parse(hlo_string, config); + return ParseHloString(hlo_string, config); } namespace { @@ -81,7 +80,7 @@ HloRunner::ReadModuleFromHloTextFile(const std::string& filename, filename, &hlo_string)); HloModuleConfig config; config.set_debug_options(debug_options); - return tools::Parse(hlo_string, config); + return ParseHloString(hlo_string, config); } HloRunner::HloRunner(se::Platform* platform) { @@ -93,53 +92,108 @@ HloRunner::HloRunner(se::Platform* platform) { HloRunner::~HloRunner() {} -StatusOr> HloRunner::Execute( - std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes) { - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - CreateExecutable(std::move(module), run_hlo_passes)); - se::Stream stream(backend().default_stream_executor()); - stream.Init(); +StatusOr 
HloRunner::TransferLiteralToDevice( + const Literal& literal) { + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer, + backend().transfer_manager()->AllocateScopedShapedBuffer( + literal.shape(), backend().memory_allocator(), + backend().default_device_ordinal())); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + backend().default_stream_executor(), literal, buffer)); + return std::move(buffer); +} - ServiceExecutableRunOptions service_run_options(GetServiceRunOptionsForDevice( - backend().default_device_ordinal(), &stream, nullptr)); - const ExecutableRunOptions& run_options = service_run_options.run_options(); +StatusOr> HloRunner::TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice literals) { + std::vector buffers; + for (const Literal* literal : literals) { + CHECK(literal != nullptr); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer, + TransferLiteralToDevice(*literal)); + buffers.push_back(std::move(buffer)); + } + return std::move(buffers); +} - // Copy arguments to device. 
- std::vector argument_buffers; - for (Literal* argument : arguments) { - TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer argument_buffer, - backend().transfer_manager()->AllocateScopedShapedBuffer( - argument->shape(), run_options.allocator(), - run_options.device_ordinal())); - TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - stream.parent(), *argument, argument_buffer)); - argument_buffers.push_back(std::move(argument_buffer)); +StatusOr> HloRunner::TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice> literals) { + std::vector literal_pointers; + literal_pointers.reserve(literals.size()); + for (const auto& literal : literals) { + literal_pointers.push_back(literal.get()); } + return TransferLiteralsToDevice(literal_pointers); +} + +StatusOr> HloRunner::TransferLiteralFromDevice( + const ShapedBuffer& buffer) { + return backend().transfer_manager()->TransferLiteralFromDevice( + backend().default_stream_executor(), buffer); +} - std::vector argument_buffer_ptrs; - argument_buffer_ptrs.reserve(argument_buffers.size()); - for (const auto& buf : argument_buffers) { - argument_buffer_ptrs.push_back(&buf); +StatusOr> HloRunner::Execute( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + TF_ASSIGN_OR_RETURN(std::vector argument_buffers, + TransferLiteralsToDevice(arguments)); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, + ExecuteWithDeviceBuffers( + /*module=*/std::move(module), + /*arguments=*/argument_buffers, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile)); + return TransferLiteralFromDevice(result); +} + +StatusOr> HloRunner::Execute( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice> arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + // Construct a vector of plain pointers for the arguments. 
+ std::vector argument_pointers; + argument_pointers.reserve(arguments.size()); + for (const auto& argument : arguments) { + argument_pointers.push_back(argument.get()); } + return Execute( + /*module=*/std::move(module), + /*arguments=*/argument_pointers, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile); +} - TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer result, - executable->ExecuteOnStreamWrapper( - &service_run_options, /*profile=*/nullptr, argument_buffer_ptrs)); - - auto result_literal = backend().transfer_manager()->TransferLiteralFromDevice( - stream.parent(), result); - if (result_literal.ok()) { - VLOG(4) << "Executed binary and got result: " - << result_literal.ValueOrDie()->ToString(); - } else { - VLOG(4) << "Executed binary and got status: " - << result_literal.status().ToString(); +StatusOr HloRunner::ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + // Get service run options. 
+ se::Stream stream(backend().default_stream_executor()); + stream.Init(); + ServiceExecutableRunOptions service_run_options = + GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream, + nullptr); + + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + CreateExecutable(std::move(module), run_hlo_passes)); + return executable->ExecuteOnStreamWrapper(&service_run_options, + /*profile=*/profile, arguments); +} + +StatusOr HloRunner::ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + std::vector argument_pointers; + argument_pointers.reserve(arguments.size()); + for (const auto& argument : arguments) { + argument_pointers.push_back(&argument); } - return result_literal; + return ExecuteWithDeviceBuffers( + /*module=*/std::move(module), + /*arguments=*/argument_pointers, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile); } StatusOr>> HloRunner::ExecuteReplicated( @@ -171,7 +225,7 @@ StatusOr>> HloRunner::ExecuteReplicated( int64 device = device_assignment(i, 0); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device)); - streams.push_back(absl::make_unique(executor)); + streams.push_back(MakeUnique(executor)); streams.back()->Init(); service_run_options.emplace_back(GetServiceRunOptionsForDevice( device, streams.back().get(), &device_assignment)); @@ -198,7 +252,7 @@ StatusOr>> HloRunner::ExecuteReplicated( num_threads += options.num_replicas; } if (num_threads > 0) { - pool = absl::make_unique( + pool = MakeUnique( tensorflow::Env::Default(), "infeed_outfeed", /*num_threads=*/num_threads); } @@ -229,7 +283,7 @@ StatusOr>> HloRunner::ExecuteReplicated( VLOG(1) << "Starting outfeed on device " << device; for (int64 step = 1; options.infeed_steps < 0 || step <= options.infeed_steps; ++step) { - auto literal = absl::make_unique(); + auto literal = MakeUnique(); 
TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( executor, options.outfeed_shape, literal.get())); if (options.outfeed_values != nullptr) { @@ -296,4 +350,8 @@ Backend& HloRunner::backend() { return *backend_; } +const Backend& HloRunner::backend() const { + return const_cast(this)->backend(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 53f7c6fe4a09111c5ee24f2290f0f4aeed0a4401..65537f07f56e74b7fe2c2f9792af21efc7229573 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -102,6 +102,15 @@ class HloRunner { static StatusOr> ReadModuleFromHloTextFile( const std::string& filename, const DebugOptions& debug_options); + // Transfers data between the host and device. + StatusOr TransferLiteralToDevice(const Literal& literal); + StatusOr> TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice literals); + StatusOr> TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice> literals); + StatusOr> TransferLiteralFromDevice( + const ShapedBuffer& buffer); + // Executes the given module with given literals as input and returns the // result as a Literal. // @@ -109,20 +118,25 @@ class HloRunner { // optimization. StatusOr> Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes = true); + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); StatusOr> Execute( std::unique_ptr module, const tensorflow::gtl::ArraySlice> arguments, - bool run_hlo_passes = true) { - // Construct a vector of plain pointers for the arguments. 
- std::vector argument_pointers; - c_transform( - arguments, std::back_inserter(argument_pointers), - [](const std::unique_ptr& literal) { return literal.get(); }); - return Execute(std::move(module), argument_pointers, run_hlo_passes); - } + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + + // As Execute(), but accepts and returns device buffers instead of host + // buffers. + StatusOr ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + + StatusOr ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); // Executes a given HLO module into a set of replicas, and returns a map // with the replica number as key, and the corresponding returned literal as @@ -137,6 +151,7 @@ class HloRunner { // This creates the backend lazily so it's possible to instantiate an // HloRunner in a program without any backends linked in. Backend& backend(); + const Backend& backend() const; private: // Creates an executable object given an HLO module. If run_hlo_passes is diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 23ace5afeab30d658e53258a7120d4a329cc90db..68b2cde83a2eb479d9ba71fc6eab9ac9ab1c8267 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -62,7 +62,34 @@ StatusOr MinimumMemoryForSequence( namespace { // Class implementing a list scheduler of HLO instructions which produces a -// sequence which minimizes memory usage. +// sequence which minimizes memory usage by preferring to schedule the node that +// frees bigger buffer and defines smaller outputs. +// +// Note that list scheduler is a greedy algorithm which cannot guarantee a +// global optimal solution. 
As a counterexample, consider the following +// graph: +// +// +--> B ===> C -------+ +// A -> | | +// | v +// +--> D ---> F=======>G +// | ^ +// | | +// +--> E -----+ +// +// --> : Buffer with size 1 +// ==> : Buffer with size 2 +// +// The list scheduler will always try to defer scheduling B in a greedy way +// since its output buffer is bigger than input. The sequence it creates will +// be: +// A D E F B C G +// , which has a maximum memory usage of 6 (B is alive while F is executing). +// +// An optimal way to schedule the previous graph is: +// A B C D E F G +// , which has a maximum memory usage of 5 (when F is executing). +// class ListScheduler { public: // Construct and return a memory-minimizing sequence of HLO instructions @@ -70,8 +97,11 @@ class ListScheduler { static StatusOr> Run( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - ListScheduler scheduler(computation, points_to_analysis, size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + ListScheduler scheduler(computation, points_to_analysis, size_function, + memory_by_computation); return scheduler.CreateSchedule(); } @@ -92,10 +122,13 @@ class ListScheduler { ListScheduler(const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) : computation_(computation), points_to_analysis_(points_to_analysis), - size_function_(size_function) { + size_function_(size_function), + memory_by_computation_(memory_by_computation) { // Create a map containing the LogicalBuffer uses for each HLO // instruction. 
An HLO instruction "uses" a LogicalBuffer if the // LogicalBuffer is in an operand of the instruction as indicated by @@ -185,6 +218,12 @@ class ListScheduler { } // Returns the number of bytes freed if the HLO instruction is scheduled. + // If the instruction calls subcomputations, we count the memory used by the + // subcomputations as memory "defined" by the instruction. This is not + // entirely accurate, because subcomputation memory will be freed after the + // instruction finishes. But it is more accurate than not taking + // subcomputations into account at all. In the future, we may improve + // accounting for subcomputation memory (b/65409243). int64 BytesFreedIfScheduled(const ReadyListEntry& entry) { int64 freed_bytes = 0; for (const auto& kv : entry.used_buffer_unscheduled_use_counts) { @@ -194,7 +233,19 @@ class ListScheduler { freed_bytes += size_function_(*buffer); } } - return freed_bytes - entry.bytes_defined; + // We only count the memory usage of the largest subcomputation, instead of + // adding them all, because subcomputations won't execute in parallel. + int64 max_subcomputation_bytes = 0; + for (const auto* c : entry.instruction->called_computations()) { + auto it = memory_by_computation_.find(c); + if (it != memory_by_computation_.end()) { + int64 subcomputation_bytes = it->second; + if (subcomputation_bytes > max_subcomputation_bytes) { + max_subcomputation_bytes = subcomputation_bytes; + } + } + } + return freed_bytes - entry.bytes_defined - max_subcomputation_bytes; } // Constructs the scheduling priority of the given instruction. 
@@ -248,6 +299,8 @@ class ListScheduler { auto best_it = ready_queue.end(); --best_it; const HloInstruction* best = best_it->second.instruction; + VLOG(2) << "Schedule instruction: " << best->ToShortString() + << " Bytes freed: " << best_it->first.first; ready_queue.erase(best_it); ready_instructions.erase(best); schedule.push_back(best); @@ -315,6 +368,11 @@ class ListScheduler { const HloComputation& computation_; const TuplePointsToAnalysis& points_to_analysis_; const LogicalBuffer::SizeFunction& size_function_; + // Computations are analyzed in post-order. When scheduling an instruction + // that includes subcomputations, such as a while loop, we use this map to + // look up the memory needed by subcomputations. + const tensorflow::gtl::FlatMap& + memory_by_computation_; // A map containing the LogicalBuffers that each instruction uses. tensorflow::gtl::FlatMap> CreateMemoryMinimizingSequence( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + VLOG(2) << "Computation: " << computation.name(); + if (algorithm) { + return algorithm(computation, points_to_analysis, size_function, + memory_by_computation); + } + return DefaultMemoryScheduler(computation, points_to_analysis, size_function, + memory_by_computation); +} + +} // namespace + StatusOr MinimumMemoryForComputation( const HloComputation& computation, const std::vector& sequence, @@ -352,24 +428,12 @@ StatusOr MinimumMemoryForComputation( return result.heap_size; } -StatusOr> CreateMemoryMinimizingSequence( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm) { - VLOG(2) << "Computation: " << computation.name(); - if (algorithm) { - return algorithm(computation, points_to_analysis, 
size_function); - } - return DefaultMemoryScheduler(computation, points_to_analysis, size_function); -} - -} // namespace - StatusOr> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { // This ordering is based on DFS post-order, with a heuristic to decide which // operand to visit first. The heuristic is based on 'extra_users', which is // simply users-1 for each instruction. By subtracting 1, we're saying that @@ -395,6 +459,13 @@ StatusOr> DFSMemoryScheduler( extra_users[hlo] += extra_users[operand]; total_sizes[hlo] += total_sizes[operand]; } + // total_sizes[hlo] transitively includes the sizes of all nodes that + // lead to it. But computation is a DAG, so we are double-counting nodes, + // which can lead to overflows for large programs. + // cumulative_total_size caps the size to prevent overflows. + // NOTE(dimvar): this is quite ugly and should be changed. It's unclear + // why we care about transitive sizes; when scheduling a node, its input + // and output buffers should be all that matters, not its "history". 
total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size); } CHECK_EQ(extra_users.size(), computation.instruction_count()); @@ -421,19 +492,24 @@ StatusOr> DFSMemoryScheduler( })); CHECK_EQ(sequence.size(), computation.instruction_count()); return sequence; -} +} // namespace xla StatusOr> ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - return ListScheduler::Run(computation, points_to_analysis, size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + return ListScheduler::Run(computation, points_to_analysis, size_function, + memory_by_computation); } StatusOr> PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { const auto& post_order = computation.MakeInstructionPostOrder(); return std::vector{post_order.begin(), post_order.end()}; @@ -442,26 +518,30 @@ StatusOr> PostOrderMemoryScheduler( StatusOr> DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - // We try both a list-scheduler based ordering and a DFS based ordering, and - // choose whichever returns a lower min-memory, not accounting for - // fragmentation. - // - // Note that this is just a heuristic. One obvious inaccuracy is that the - // memory required for sub-computations might be different when considered - // within the caller's context. But it's good enough for now. 
+ const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + // We try a few schedulers and choose whichever returns a lower min-memory, + // not accounting for fragmentation. + // - List is a scheduler that uses greedy heuristics. + // - DFS visits HLOs in postorder, with a heuristic to decide the order of + // children. + // - Postorder does not use any heuristics. + // List wins for most of our benchmarks; postorder-based schedulers win for + // some RNNs. TF_ASSIGN_OR_RETURN( std::vector list_sequence, - ListMemoryScheduler(computation, points_to_analysis, size_function)); + ListMemoryScheduler(computation, points_to_analysis, size_function, + memory_by_computation)); TF_ASSIGN_OR_RETURN( const int64 list_memory, MinimumMemoryForComputation(computation, list_sequence, points_to_analysis, size_function)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); - TF_ASSIGN_OR_RETURN( - std::vector dfs_sequence, - DFSMemoryScheduler(computation, points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN(std::vector dfs_sequence, + DFSMemoryScheduler(computation, points_to_analysis, + size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN( const int64 dfs_memory, MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, @@ -470,7 +550,8 @@ StatusOr> DefaultMemoryScheduler( TF_ASSIGN_OR_RETURN( std::vector post_order_sequence, - PostOrderMemoryScheduler(computation, points_to_analysis, size_function)); + PostOrderMemoryScheduler(computation, points_to_analysis, size_function, + memory_by_computation)); TF_ASSIGN_OR_RETURN( const int64 post_order_memory, MinimumMemoryForComputation(computation, post_order_sequence, @@ -478,19 +559,20 @@ StatusOr> DefaultMemoryScheduler( VLOG(2) << "Min-memory post order sequence: " << HumanReadableNumBytes(post_order_memory); - if (post_order_memory < std::min(list_memory, dfs_memory)) { - VLOG(2) << "Chose min-memory post_order 
sequence: " - << HumanReadableNumBytes(post_order_memory); - return post_order_sequence; + auto min_memory = std::min({dfs_memory, post_order_memory, list_memory}); - } else if (list_memory <= dfs_memory) { + if (min_memory == list_memory) { VLOG(2) << "Chose min-memory list sequence: " << HumanReadableNumBytes(list_memory); return list_sequence; - } else { + } else if (min_memory == dfs_memory) { VLOG(2) << "Chose min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); return dfs_sequence; + } else { + VLOG(2) << "Chose min-memory post_order sequence: " + << HumanReadableNumBytes(post_order_memory); + return post_order_sequence; } } @@ -501,24 +583,32 @@ CreateMemoryMinimizingSequence(const HloModule& module, SequentialHloOrdering::HloModuleSequence sequence; TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); - for (const auto* computation : module.MakeNonfusionComputations()) { - TF_ASSIGN_OR_RETURN( - sequence[computation], - CreateMemoryMinimizingSequence(*computation, *points_to_analysis, - size_function, algorithm)); + tensorflow::gtl::FlatMap memory_by_computation; + for (const auto* computation : module.MakeComputationPostOrder()) { + if (!computation->IsFusionComputation()) { + TF_ASSIGN_OR_RETURN(auto one_computation_sequence, + CreateMemoryMinimizingSequence( + *computation, *points_to_analysis, size_function, + algorithm, memory_by_computation)); + memory_by_computation[computation] = + MinimumMemoryForComputation(*computation, one_computation_sequence, + *points_to_analysis, size_function) + .ValueOrDie(); + sequence[computation] = std::move(one_computation_sequence); + } } return sequence; } StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm) { + const LogicalBuffer::SizeFunction& size_function) { CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr 
points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); + tensorflow::gtl::FlatMap empty_map; return CreateMemoryMinimizingSequence(computation, *points_to_analysis, - size_function, algorithm); + size_function, nullptr, empty_map); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index fcb006f818fd1d55a09475042779dd60de945697..49b927eefd24f4e26df781dd8d2b977bedba2b80 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -34,32 +34,47 @@ StatusOr MinimumMemoryForSequence( const SequentialHloOrdering::HloModuleSequence& module_sequence, const LogicalBuffer::SizeFunction& size_function); +// Returns the minimum memory required to compute the given computation, +// assuming no fragmentation. +StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function); + // A memory scheduler computes an execution sequence for the HLO instructions in // 'computation' that minimizes peak memory, given a points-to analysis result // that describes buffer aliasing, together with a target-specific size function // that maps a tensor's logical size to its padded size. 
typedef std::function>( const HloComputation&, const TuplePointsToAnalysis&, - const LogicalBuffer::SizeFunction&)> + const LogicalBuffer::SizeFunction&, + const tensorflow::gtl::FlatMap&)> MemorySchedulerAlgorithm; // List scheduler StatusOr> ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); // DFS-order scheduler StatusOr> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); // Naive Post Order scheduler StatusOr> PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); // The default scheduling algorithm. Runs both the list scheduler // and the DFS scheduler, and chooses whichever returns a lower min-memory, @@ -67,7 +82,9 @@ StatusOr> PostOrderMemoryScheduler( StatusOr> DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); // Returns an HloModuleSequence which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes @@ -78,10 +95,10 @@ CreateMemoryMinimizingSequence(const HloModule& module, const MemorySchedulerAlgorithm& algorithm = {}); // Overload of above that computes the sequence for a single computation. 
+// Currently only used by the GPU backend. StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm = {}); + const LogicalBuffer::SizeFunction& size_function); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index 92df7c1427f282ccdde2df494c41b3f2a98cf7b3..db7ef6f0d4bd96216ea07ccc75a51513822bf2e3 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -22,9 +22,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -158,7 +158,7 @@ ENTRY root { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(module_str)); + ParseHloString(module_str)); auto size_fn = [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); @@ -190,5 +190,199 @@ ENTRY root { instructions_by_name.at("e"))); } +TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { + // %WhileCond (cond_param: f32[4]) -> pred[] { + // %cond_param = f32[4]{0} parameter(0) + // %constant = f32[1,4]{1,0} constant(f32[1,4] { { 0, 0, 0, 0 } }) + // ROOT %not-equal-to = pred[] not-equal-to( + // f32[4]{0} %cond_param, f32[1,4]{1,0} %constant) + // } + // %WhileBody (body_param: f32[4]) -> f32[4] { + // %body_param = f32[4]{0} parameter(0) + // %constant.1 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) + 
// ROOT %subtract = f32[4]{0} subtract( + // f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1) + // } + // %SubcomputationsNotAccounted () -> f32[2,4] { + // %constant.3 = f32[2,4]{1,0} constant( + // f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } }) + // %transpose = f32[2,4]{1,0} transpose( + // f32[2,4]{1,0} %constant.3), dimensions={0,1} + // %constant.2 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) + // %while = f32[4]{0} while(f32[1,4]{1,0} %constant.2), + // condition=%WhileCond, + // body=%WhileBody + // %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={0} + // ROOT %add = f32[2,4]{1,0} add( + // f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast) + // } + + auto module = CreateNewModule(); + const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); + const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); + + // param != 0 + // Needs 17 bytes + auto cond_builder = HloComputation::Builder("WhileCond"); + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "cond_param")); + HloInstruction* zero_vector = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{0, 0, 0, 0}}))); + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); + auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); + + // param - 1 + // Needs 16 bytes + auto body_builder = HloComputation::Builder("WhileBody"); + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "body_param")); + HloInstruction* one_vector = body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + body_builder.AddInstruction(HloInstruction::CreateBinary( + r1f32, HloOpcode::kSubtract, body_param, one_vector)); + auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); + + // transpose(matrix) + bcast(while) + auto 
builder = HloComputation::Builder(TestName()); + HloInstruction* while_init = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + // Creates 16 bytes, ignoring subcomputations + HloInstruction* while_loop = + builder.AddInstruction(HloInstruction::CreateWhile( + r1f32, cond_computation, body_computation, while_init)); + + // Creates 32 bytes and frees 16 + HloInstruction* bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, while_loop, {0})); + + HloInstruction* matrix = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2( + {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}}))); + // Creates 32 bytes + HloInstruction* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(r2f32, matrix, {0, 1})); + + // Creates 32 bytes and frees 64 + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast)); + + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, + CreateMemoryMinimizingSequence( + *module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }, + ListMemoryScheduler)); + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + SequentialHloOrdering ordering(module.get(), sequence); + // This schedule is an example of List's greedy heuristics being suboptimal. + // The while_loop is more expensive than transpose, so it would have been + // better to schedule it first, instead of during the busy time. 
+ EXPECT_TRUE(ordering.ExecutesBefore(transpose, while_loop)); + EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast)); + EXPECT_TRUE(ordering.ExecutesBefore(bcast, add)); + EXPECT_TRUE(ordering.ExecutesBefore(transpose, add)); +} + +TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { + auto builder = HloComputation::Builder(TestName()); + const auto TUPLE_SIZE = 1; + const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {6}); + + // Wrap lit in abs because constants are considered free by + // IgnoreInstruction, and it skews the accounting. + auto lit = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1, 1, 1, 1, 1, 1}))); + auto abs_const = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, lit)); + + auto abs_abs1 = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple( + tensorflow::gtl::ArraySlice({abs_abs1}))); + auto tuple_elm = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); + + auto abs_abs2 = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); + + builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, + tuple_elm, abs_abs2)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + CreateMemoryMinimizingSequence(*module, + [&TUPLE_SIZE](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), TUPLE_SIZE); + }, + ListMemoryScheduler)); + + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + SequentialHloOrdering ordering(module.get(), sequence); + // tuple allocates the tuple buffer and doesn't free anything. 
+ // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0. + // abs_abs2 should be scheduled before tuple by List. + EXPECT_TRUE(ordering.ExecutesBefore(abs_abs2, tuple)); +} + +TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { + const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {5}); + HloComputation::Builder builder(TestName()); + + auto c1 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1, 1, 1, 1, 1}))); + auto c2 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1, 2, 3, 4, 5}))); + auto c3 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({0, 2, 4, 6, 8}))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, c1, c2)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kMultiply, add, c3)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, mul})); + + auto tuple_elm = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); + + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kExp, c3)); + + builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto fusion = computation->CreateFusionInstruction( + {tuple, mul, add}, HloInstruction::FusionKind::kLoop); + + TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, + CreateMemoryMinimizingSequence( + *module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), 2); + }, + ListMemoryScheduler)); + + // Verify that all instructions are in the sequence. 
+ EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + SequentialHloOrdering ordering(module.get(), sequence); + // fusion allocates memory for the tuple elements and doesn't free anything, + // so it's more expensive than exp. + EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 7f7e3f7dab03ce0ad64bd0fcfe4ddd020d31bf56..4fbb7f69acca5a42bfe824a04e800a1a7018ef28 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -49,9 +49,6 @@ string HloSharding::ToString() const { return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}"); } - string result = StrCat("{", (replicated_ ? " replicated" : ""), - (maximal_ ? " maximal" : "")); - if (replicated_) { return "{replicated}"; } else if (maximal_) { @@ -126,6 +123,38 @@ std::vector HloSharding::TileLimitForDevice(int64 device) const { return index; } +StatusOr> HloSharding::AsShapeTree( + const Shape& shape) const { + if (IsTuple()) { + ShapeTree result(shape, HloSharding::Replicate()); + int64 num_leaves = result.leaf_count(); + TF_RET_CHECK(num_leaves == tuple_elements_.size()) + << "Shape " << ShapeUtil::HumanString(shape) << " has " << num_leaves + << " leaf nodes while this sharding has " << tuple_elements_.size(); + auto it = tuple_elements_.begin(); + for (auto& index_to_sharding : result.leaves()) { + index_to_sharding.second = *it++; + } + return std::move(result); + } else { + return ShapeTree(shape, *this); + } +} + +StatusOr HloSharding::GetTupleSharding(const Shape& shape) const { + if (IsTuple()) { + // TODO(b/109903108): An empty tuple has one leaf for ShapeTree, while it + // has zero leaves for ShapeUtil. This needs cleanup. + int64 shape_leaves = + ShapeUtil::IsEmptyTuple(shape) ? 
1 : ShapeUtil::GetLeafCount(shape); + TF_RET_CHECK(shape_leaves == tuple_elements_.size()) + << "Shape " << ShapeUtil::HumanString(shape) << " has " << shape_leaves + << " leaf nodes while this sharding has " << tuple_elements_.size(); + return *this; + } + return Tuple(ShapeTree(shape, *this)); +} + StatusOr HloSharding::UniqueDevice() const { if (IsTuple()) { if (tuple_elements_.empty()) { @@ -370,11 +399,21 @@ HloSharding HloSharding::GetSubSharding(const Shape& shape, Shape sub_shape = ShapeUtil::GetSubshape(shape, index); ShapeTree sub_shape_tree(sub_shape, Replicate()); sub_shape_tree.CopySubtreeFrom(GetAsShapeTree(shape), index, {}); - if (ShapeUtil::IsTuple(sub_shape)) { - return Tuple(sub_shape_tree); - } else { - return sub_shape_tree.element({}); + return ShapeUtil::IsTuple(sub_shape) ? Tuple(sub_shape_tree) + : sub_shape_tree.element(ShapeIndex({})); +} + +tensorflow::gtl::optional HloSharding::ExtractSingleSharding() + const { + if (!IsTuple()) { + return *this; + } + for (int64 i = 1; i < tuple_elements_.size(); ++i) { + if (tuple_elements_[0] != tuple_elements_[i]) { + return tensorflow::gtl::optional(); + } } + return tuple_elements_.front(); } std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) { diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 2b8e757f42991f697df37d3d34bfdff6a36bc509..0a213311b4a936978fee2ee23a3b7317b612e494 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -72,8 +72,7 @@ class HloSharding { // elements for every leaf shape contained in the tuple. 
static HloSharding Tuple(const ShapeTree& sub_shardings) { std::vector flattened_list; - flattened_list.reserve( - std::distance(sub_shardings.leaf_begin(), sub_shardings.leaf_end())); + flattened_list.reserve(sub_shardings.leaf_count()); for (const auto& index_to_sharding : sub_shardings.leaves()) { flattened_list.push_back(index_to_sharding.second); } @@ -99,6 +98,9 @@ class HloSharding { static bool IsReservedDevice(int64 device) { return device < 0; } OpSharding ToProto() const; + + // Note that this string canonically has outer curly braces, e.g. + // "{replicated}". string ToString() const; // Validate that this sharding can be applied to a tensor with shape `shape`. @@ -160,25 +162,27 @@ class HloSharding { // tuple, if IsTuple, or a ShapeTree with a single element containing this // sharding. Only the leaf elements are populated. This creates a new // ShapeTree object so is not cheap. + StatusOr> AsShapeTree(const Shape& shape) const; ShapeTree GetAsShapeTree(const Shape& shape) const { - if (IsTuple()) { - ShapeTree result(shape, HloSharding::Replicate()); - CHECK_EQ(std::distance(result.leaf_begin(), result.leaf_end()), - tuple_elements_.size()); - auto it = tuple_elements_.begin(); - for (auto& index_to_sharding : result.leaves()) { - index_to_sharding.second = *it++; - } - return result; - } else { - return ShapeTree(shape, *this); - } + return AsShapeTree(shape).ValueOrDie(); } // Retrieves the sub sharding at a given index, out of a tuple sharding. // REQUIRES: IsTuple() HloSharding GetSubSharding(const Shape& shape, const ShapeIndex& index) const; + // If the current sharding is a tuple sharding, return itself as result. + // Otherwise returns a tuple sharding for the input shape, with all the leaves + // having this object sharding. + StatusOr GetTupleSharding(const Shape& shape) const; + + // Extracts the sharding that is common within the current sharding. 
+ // If the current sharding is not a tuple sharding, the current sharding will + // be returned. If it is a tuple, and all the tuple elements are common, the + // common element will be returned. Otherwise the optional will contain no + // value. + tensorflow::gtl::optional ExtractSingleSharding() const; + bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && ShapeUtil::Compatible(tile_shape_, other.tile_shape_) && @@ -208,6 +212,12 @@ class HloSharding { return h; } + struct Hasher { + size_t operator()(const HloSharding& sharding) const { + return sharding.Hash(); + } + }; + // Gets the tile shape. // REQUIRES: !IsTileMaximal() && !IsTuple() const Shape& tile_shape() const { return tile_shape_; } diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc new file mode 100644 index 0000000000000000000000000000000000000000..7b4b071af46df19520f9ba1f1f632692d489de59 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -0,0 +1,394 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +namespace { + +struct PassThrough { + PassThrough(HloInstruction* user, HloInstruction* operand) + : user(user), operand(operand) {} + + HloInstruction* user = nullptr; + HloInstruction* operand = nullptr; +}; + +void SetSingleSharding(HloInstruction* instruction, + const HloSharding& sharding) { + VLOG(4) << " " << instruction->name() << " to " << sharding; + instruction->set_single_sharding(sharding); +} + +bool ShardingMatches(const HloSharding& sharding1, + const HloSharding& sharding2) { + auto single_sharding1 = sharding1.ExtractSingleSharding(); + if (single_sharding1) { + auto single_sharding2 = sharding2.ExtractSingleSharding(); + if (single_sharding2) { + return *single_sharding1 == single_sharding2; + } + } + // Anything which is not unique across all elements, gets a full sharding + // compare. + return sharding1 == sharding2; +} + +// When we create domains, they are never "empty", where with empty we mean +// that a kDomain instruction has as operand another kDomain instruction of the +// same kind. +// But when the HLO optimizations are run, empty domains can be created. +// For example: +// +// Domain(device=None, device=0) -> +// Tuple(device=0) -> +// GTE(device=0) -> +// Domain(device=0, device=None) +// +// In that case the tuple simplifier could create something like: +// +// Domain(device=None, device=0) -> Domain(device=0, device=None) +// +// Which is a so called empty domain. +// In the case above, crossing an empty domain which was transiting through +// device 0, requires the normalization phase to fixup the empty domain by +// adding back a Tuple+GTE pair with the proper device. 
+// One particular case where this can create problems is the result of the +// entry computation, where the GTE assignments are used by TF to tell the +// XLA where the results should be sent. +std::vector LocatePassThroughDomainLinks( + const DomainMetadata::Domain& domain) { + std::vector pass_through; + for (HloInstruction* instruction : domain.enter_domains) { + CHECK(instruction->opcode() == HloOpcode::kDomain) + << "Instruction is not a kDomain: " << instruction->ToString(); + for (HloInstruction* user : instruction->users()) { + if (user->opcode() == HloOpcode::kDomain && + domain.exit_domains.count(user) != 0) { + pass_through.emplace_back(user, instruction); + VLOG(2) << "Found passthrough domain link:"; + VLOG(2) << " " << user->ToString(); + VLOG(2) << " " << instruction->ToString(); + } + } + } + return pass_through; +} + +Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + for (auto& pass_through : LocatePassThroughDomainLinks(domain)) { + HloInstruction* tuple = pass_through.operand->parent()->AddInstruction( + HloInstruction::CreateTuple({pass_through.operand})); + HloInstruction* gte = pass_through.operand->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement(pass_through.operand->shape(), + tuple, 0)); + gte->set_sharding(sharding); + TF_RETURN_IF_ERROR( + pass_through.operand->ReplaceUseWith(pass_through.user, gte)); + } + return Status::OK(); +} + +std::unique_ptr CloneShardingForDomain( + const HloSharding& sharding) { + auto single_sharding = sharding.ExtractSingleSharding(); + if (!single_sharding) { + return MakeUnique(sharding); + } + return MakeUnique(*single_sharding); +} + +Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + VLOG(4) << "Applying " << sharding << " sharding"; + for (HloInstruction* instruction : domain.instructions) { + // We only change instructions without sharding, since otherwise we might + // 
mess up with eventual HLO passes which has knowledge of it. + if (!instruction->has_sharding()) { + SetSingleSharding(instruction, sharding); + } else { + VLOG(4) << " " << instruction->name() << " already has sharding " + << instruction->sharding(); + } + } + return Status::OK(); +} + +// Retrieves the sharding of a tuple shaped instruction in form of a ShapeTree. +// If the instruction has no sharding, a ShapeTree with HloSharding::Replicate() +// sharding will be returned. +ShapeTree GetTupleSharding(HloInstruction* tuple) { + if (tuple->has_sharding()) { + return tuple->sharding().GetAsShapeTree(tuple->shape()); + } + return ShapeTree(tuple->shape(), HloSharding::Replicate()); +} + +// Retrieves the sharding of operand, asked from a user instruction which is +// within domain. If operand is a kDomain, it means that sharding argument is +// the operand sharding, otherwise the operand's own sharding will be returned. +const HloSharding* GetOperandSharding(const HloInstruction* operand, + const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + DCHECK_EQ(domain.reach_set.count(const_cast(operand)), 1); + // Here the user of operand is within the domain instruction set, and since it + // is user of operand, we need to look into the enter_domains set. If this is + // not a kDomain within the user domains set, then return the operand + // sharding, if any. + if (operand->opcode() != HloOpcode::kDomain || + domain.enter_domains.count(const_cast(operand)) == 0) { + return operand->has_sharding() ? &operand->sharding() : nullptr; + } + // At this point operand is a kDomain of the currently processed domain, so we + // can refer to sharding as the domain sharding. + return &sharding; +} + +// Tries to propagate the sharding information into the instructions that are +// part of the domain, in a post order manner (operand propagate to user). 
+StatusOr ApplyDomainShardingPass(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + int64 assigned = 0; + for (HloInstruction* instruction : domain.instructions) { + if (instruction->has_sharding()) { + continue; + } + if (instruction->opcode() == HloOpcode::kGetTupleElement) { + HloInstruction* tuple = instruction->mutable_operand(0); + const HloSharding* tuple_sharding = + GetOperandSharding(tuple, domain, sharding); + if (tuple_sharding != nullptr) { + if (tuple_sharding->IsTuple()) { + HloSharding sub_sharding = tuple_sharding->GetSubSharding( + tuple->shape(), {instruction->tuple_index()}); + VLOG(4) << " " << instruction->name() << " to sharding " + << sub_sharding; + instruction->set_sharding(sub_sharding); + } else { + SetSingleSharding(instruction, *tuple_sharding); + } + ++assigned; + } + } else if (instruction->opcode() == HloOpcode::kTuple) { + int64 tuple_assigned = 0; + ShapeTree shape_tree = GetTupleSharding(instruction); + for (int64 i = 0; i < instruction->operand_count(); ++i) { + const HloSharding* operand_sharding = + GetOperandSharding(instruction->operand(i), domain, sharding); + if (operand_sharding != nullptr && + shape_tree.element({i}) != *operand_sharding) { + *shape_tree.mutable_element({i}) = *operand_sharding; + ++tuple_assigned; + } + } + if (tuple_assigned > 0) { + HloSharding tuple_sharding = HloSharding::Tuple(shape_tree); + VLOG(4) << " " << instruction->name() << " to sharding " + << tuple_sharding; + instruction->set_sharding(tuple_sharding); + ++assigned; + } + } else { + // If all the operand of the given instruction has the same single device + // assignment, assign that device to this instruction as well. 
+ const HloSharding* common_sharding = nullptr; + for (const HloInstruction* operand : instruction->operands()) { + const HloSharding* operand_sharding = + GetOperandSharding(operand, domain, sharding); + if (operand_sharding != nullptr) { + if (common_sharding != nullptr && + *common_sharding != *operand_sharding) { + common_sharding = nullptr; + break; + } + common_sharding = operand_sharding; + } + } + if (common_sharding != nullptr) { + VLOG(4) << " " << instruction->name() << " to sharding " + << *common_sharding; + instruction->set_sharding(*common_sharding); + ++assigned; + } + } + } + return assigned; +} + +Status ApplyDomainSharding(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + auto single_sharding = sharding.ExtractSingleSharding(); + if (single_sharding) { + // Shortcut the simple case. We have a unique sharding, so we call + // the ApplyDomainSingleSharding() API which will apply array or tuple + // shaped sharding to the domain instructions. + return ApplyDomainSingleSharding(domain, *single_sharding); + } + VLOG(1) << "Assigning non-trivial sharding " << sharding; + for (;;) { + TF_ASSIGN_OR_RETURN(int64 assigned, + ApplyDomainShardingPass(domain, sharding)); + if (assigned == 0) { + break; + } + } + int64 unassigned = 0; + for (HloInstruction* instruction : domain.instructions) { + if (!instruction->has_sharding()) { + LOG(WARNING) << "Unassigned instruction: " << instruction->ToString(); + ++unassigned; + } + } + // Should we error out if unassigned > 0? + return Status::OK(); +} + +// Creates a kDomain instruction to be placed between instruction and operand. +// The kDomain instruction will be created only if the sharding differ between +// the instruction and the operand. +std::unique_ptr CreateDomain(HloInstruction* instruction, + HloInstruction* operand) { + const HloSharding* instruction_sharding = + instruction->has_sharding() ? 
&instruction->sharding() : nullptr; + const HloSharding* operand_sharding = + operand->has_sharding() ? &operand->sharding() : nullptr; + // No need for domain if they both have no sharding. + if (instruction_sharding == nullptr && operand_sharding == nullptr) { + return nullptr; + } + // No need for domain if they match. + if (instruction_sharding != nullptr && operand_sharding != nullptr && + ShardingMatches(*instruction_sharding, *operand_sharding)) { + return nullptr; + } + std::unique_ptr real_instruction_sharding; + std::unique_ptr real_operand_sharding; + if (instruction_sharding != nullptr) { + real_instruction_sharding = CloneShardingForDomain(*instruction_sharding); + } + if (operand_sharding != nullptr) { + real_operand_sharding = CloneShardingForDomain(*operand_sharding); + } + VLOG(3) << "Creating domain:"; + VLOG(3) << " Instruction: " << instruction->name(); + VLOG(3) << " Operand: " << operand->name(); + VLOG(3) << " User side sharding: " + << (real_instruction_sharding != nullptr + ? real_instruction_sharding->ToString() + : "None"); + VLOG(3) << " Operand side sharding: " + << (real_operand_sharding != nullptr + ? real_operand_sharding->ToString() + : "None"); + + std::unique_ptr operand_side_metadata = + MakeUnique(std::move(real_operand_sharding)); + std::unique_ptr user_side_metadata = + MakeUnique(std::move(real_instruction_sharding)); + return HloInstruction::CreateDomain(operand->shape(), operand, + std::move(operand_side_metadata), + std::move(user_side_metadata)); +} + +StatusOr> ExtractOriginalCommonSharding( + tensorflow::gtl::ArraySlice instructions) { + // If we are here, all the instructions being passed had the same sharding + // (or no sharding), by the means of the ShardingMatches() API. + // As such, no kDomain was inserted, and here we are asked to extract the + // original common sharding. + // All the instructions passed to this API are part of the same computation. 
+ const HloSharding* sharding = nullptr; + for (HloInstruction* instruction : instructions) { + if (instruction->has_sharding()) { + if (sharding == nullptr) { + sharding = &instruction->sharding(); + } else { + TF_RET_CHECK(ShardingMatches(*sharding, instruction->sharding())) + << "Sharding " << *sharding << " does not match the one in " + << instruction->ToString(); + } + } + } + if (sharding == nullptr) { + return std::unique_ptr(); + } + VLOG(4) << "Extracted sharding is " << *sharding; + return CloneShardingForDomain(*sharding); +} + +} // namespace + +std::unique_ptr ShardingMetadata::Clone() const { + std::unique_ptr sharding; + if (sharding_ != nullptr) { + sharding = MakeUnique(*sharding_); + } + return MakeUnique(std::move(sharding)); +} + +bool ShardingMetadata::Matches(const DomainMetadata& other) const { + const ShardingMetadata* other_ptr = + dynamic_cast(&other); + if (other_ptr == nullptr) { + // If other is not a ShardingMetadata, then it is clearly a no match. + return false; + } + if (sharding_ == nullptr) { + return other_ptr->sharding_ == nullptr; + } + return other_ptr->sharding_ != nullptr + ? ShardingMatches(*sharding_, *other_ptr->sharding_) + : false; +} + +string ShardingMetadata::ToString() const { + return sharding_ != nullptr ? 
sharding_->ToString() : "None"; +} + +Status ShardingMetadata::NormalizeInstructions( + const DomainMetadata::Domain& domain) const { + if (sharding_ != nullptr) { + VLOG(4) << "Normalizing sharding to " << sharding_->ToString() << ":"; + TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding_)); + TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding_)); + } + return Status::OK(); +} + +Status NormalizeShardingDomain(const DomainMetadata::Domain& domain) { + TF_ASSIGN_OR_RETURN(std::unique_ptr sharding, + ExtractOriginalCommonSharding(domain.instructions)); + if (sharding != nullptr) { + VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString() + << ":"; + TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding)); + } else { + VLOG(1) << "Unable to find common sharding"; + } + return Status::OK(); +} + +std::unique_ptr CreateShardingDomain( + HloInstruction* instruction, HloInstruction* operand) { + return CreateDomain(instruction, operand); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h new file mode 100644 index 0000000000000000000000000000000000000000..ec162c34904ee2dfac3daeeee37133282a9c9698 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ + +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +// A DomainMetadata implementation that internally wraps a sharding attribute. +class ShardingMetadata : public DomainMetadata { + public: + explicit ShardingMetadata(std::unique_ptr sharding) + : sharding_(std::move(sharding)) {} + + std::unique_ptr Clone() const override; + + tensorflow::StringPiece Kind() const override { return KindName(); } + + bool Matches(const DomainMetadata& other) const override; + + string ToString() const override; + + Status NormalizeInstructions( + const DomainMetadata::Domain& domain) const override; + + static tensorflow::StringPiece KindName() { return "sharding"; } + + private: + std::unique_ptr sharding_; +}; + +// Within a set of instructions which had common sharding attributes before +// entering the HLO passes pipeline, apply sharding heuristics and normalize the +// instructions whose sharding deviates from the one which is inferred as to be +// the original one. +// Policy wise, HLO passes are allowed to create new unassigned instructions, +// but if they do create assigned ones, they have to conform to the ones around. +Status NormalizeShardingDomain(const DomainMetadata::Domain& domain); + +// Given an HLO graph edge between instruction and one of its operands, creates +// a ShardingMetadata based kDomain instruction if the sharding between +// instruction and operand changes. Returns nullptr if there is no need for a +// domain separation.
+std::unique_ptr CreateShardingDomain( + HloInstruction* instruction, HloInstruction* operand); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 3bf0d25efb7fad78aeccdd9269c289950b2171ab..ee7133689b15348a18e6db9181199d5b25bf8143 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -13,14 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/hlo_sharding.h" - #include #include #include #include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -312,5 +311,48 @@ TEST_F(HloShardingTest, OstreamTest) { EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}"); } +TEST_F(HloShardingTest, ParseHloString) { + auto check = [](const HloSharding& sharding) { + TF_ASSERT_OK_AND_ASSIGN(auto parsed_sharding, + ParseSharding(sharding.ToString())); + EXPECT_EQ(sharding, parsed_sharding); + }; + check(HloSharding::Replicate()); + check(HloSharding::AssignDevice(2)); + check(HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), + Array4D({{{{0}, {1}}}}))); + // Empty tuple. + check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}), {})); + { + // Non-nested tuple. 
+ auto tuple_shape = + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 1, 5, 7}), + ShapeUtil::MakeShape(F32, {3, 5, 7}), + ShapeUtil::MakeShape(F32, {3, 7})}); + check(HloSharding::Tuple( + tuple_shape, {HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), + Array4D({{{{0}, {1}}}})), + HloSharding::Replicate(), HloSharding::AssignDevice(1)})); + } + { + // Nested tuple. + auto tuple_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 1, 5, 7}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5, 7}), + ShapeUtil::MakeShape(F32, {3, 7})})}); + std::vector leaf_shardings = { + HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), + Array4D({{{{0}, {1}}}})), + HloSharding::Replicate(), HloSharding::AssignDevice(1)}; + ShapeTree sharding_tree(tuple_shape, HloSharding::Replicate()); + // Assign leaf_shardings to sharding_tree leaves. + auto it = leaf_shardings.begin(); + for (auto& index_to_sharding : sharding_tree.leaves()) { + index_to_sharding.second = *it++; + } + check(HloSharding::Tuple(sharding_tree)); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/service/hlo_token.h similarity index 84% rename from tensorflow/compiler/xla/tools/parser/hlo_token.h rename to tensorflow/compiler/xla/service/hlo_token.h index 7928bee5c2097f353b182095a555c334d7b69c95..533429608bc2e13626a3e746fbe465398e1f4bb4 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_token.h +++ b/tensorflow/compiler/xla/service/hlo_token.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ #include @@ -22,9 +22,11 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace xla { -namespace tools { // Defines different kinds of tokens in a hlo module string. +// +// You shouldn't need to use this directly unless you're using HloLexer +// directly, and you probably don't need to do that. Use hlo_parser instead. enum class TokKind { // Markers kEof, @@ -72,7 +74,6 @@ enum class TokKind { string TokKindToString(TokKind kind); -} // namespace tools } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 096ebb7946e08ba697a2c5eb93a71255586e489d..9034073cc8a82311297ccd087741e6713110a5a7 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -106,9 +106,7 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { reduce_precision->mantissa_bits())); } -Status ShapeVerifier::HandleInfeed(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleInfeed(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { // Outfeed has a separate shape field for the value which is outfed to the @@ -127,12 +125,10 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { } Status ShapeVerifier::HandleHostCompute(HloInstruction*) { - return tensorflow::Status::OK(); + return Status::OK(); } -Status ShapeVerifier::HandleRng(HloInstruction*) { - return tensorflow::Status::OK(); -} 
+Status ShapeVerifier::HandleRng(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { return CheckShape( @@ -164,7 +160,7 @@ Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { } Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { @@ -183,7 +179,7 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { operand_shape.dimensions(operand_dimension)) << broadcast->ToString() << " operand shape " << operand_shape; } - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { @@ -191,7 +187,7 @@ Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape())); TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) == ShapeUtil::ElementsIn(reshape->operand(0)->shape())); - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { @@ -201,21 +197,17 @@ Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { } Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { - return tensorflow::Status::OK(); + return Status::OK(); } -Status ShapeVerifier::HandleFusion(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleCall(HloInstruction* call) { // The shape of kCall should match the shape of the computation it calls. 
return CheckShape(call, call->to_apply()->ComputeProgramShape().result()); } -Status ShapeVerifier::HandleCustomCall(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleSlice(HloInstruction* slice) { return CheckShape(slice, @@ -384,6 +376,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kConstant: case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: + case HloOpcode::kDomain: case HloOpcode::kFusion: case HloOpcode::kGetTupleElement: case HloOpcode::kInfeed: @@ -433,6 +426,14 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { gather->gather_dimension_numbers(), gather->gather_window_bounds())); } +Status ShapeVerifier::HandleGenerateToken(HloInstruction* token) { + std::vector operand_shapes; + for (const HloInstruction* operand : token->operands()) { + operand_shapes.push_back(&operand->shape()); + } + return CheckShape(token, ShapeInference::InferTokenShape(operand_shapes)); +} + Status ShapeVerifier::CheckShape(const HloInstruction* instruction, const Shape& inferred_shape) { // If allow_mixed_precision_ is false, check if there are operands with @@ -497,7 +498,7 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, ShapeUtil::HumanString(instruction->shape()).c_str(), instruction->ToString().c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::CheckShape(const HloInstruction* instruction, @@ -547,7 +548,7 @@ Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1, instr1->ToString().c_str(), instr1->channel_id(), instr2->ToString().c_str(), instr2->channel_id()); } - return tensorflow::Status::OK(); + return Status::OK(); } string ComputationsToString( @@ -612,7 +613,7 @@ Status VerifyHloStructure(HloModule* module) { } } } - return tensorflow::Status::OK(); + return Status::OK(); } Status 
HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { @@ -728,7 +729,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { // TODO(b/65423525): We'd like to check that all operands are distinct. // This is currently disabled due to the invariant being violated by // multi-output fusion. - return tensorflow::Status::OK(); + return Status::OK(); } Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { @@ -777,7 +778,7 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { "init: %s, body: %s", init->ToString().c_str(), body_root->ToString().c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { @@ -795,9 +796,49 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { ShapeUtil::HumanString(operand_shape).c_str()); } } - return tensorflow::Status::OK(); + return Status::OK(); +} + +namespace { + +// Returns true if the given Shape has a TOKEN shape as any subshape. +bool ShapeContainsToken(const Shape& shape) { + bool contains_token = false; + ShapeUtil::ForEachSubshape( + shape, [&contains_token](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsToken(subshape)) { + contains_token = true; + } + }); + return contains_token; +} + +// Verifies that all types entering and exiting the entry computation are +// legal. For example, TOKEN types have no Literal representation and cannot be +// on the interface of the entry computation (parameters and root instruction). 
+Status VerifyEntryAndExitShapes(const HloModule& module) { + for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) { + HloInstruction* param = + module.entry_computation()->parameter_instruction(i); + if (ShapeContainsToken(param->shape())) { + return InternalError( + "Entry parameter %d is or contains a token shape: %s", i, + ShapeUtil::HumanString(param->shape()).c_str()); + } + } + if (ShapeContainsToken( + module.entry_computation()->root_instruction()->shape())) { + return InternalError( + "Entry root is or contains a token shape: %s", + ShapeUtil::HumanString( + module.entry_computation()->root_instruction()->shape()) + .c_str()); + } + return Status::OK(); } +} // namespace + StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(VerifyHloStructure(module)); @@ -858,6 +899,8 @@ StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get())); } + TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module)); + return false; } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 6208887547a14d22b512ef308dd2668af2f4468d..7283b3e7dcdbed5be18a1da1571287cf0c089288 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -81,10 +81,9 @@ class ShapeVerifier : public DfsHloVisitor { HloInstruction* batch_norm_inference) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleGather(HloInstruction* gather) override; + Status HandleGenerateToken(HloInstruction* token) override; - Status FinishVisit(HloInstruction*) override { - return tensorflow::Status::OK(); - } + Status FinishVisit(HloInstruction*) override { return Status::OK(); } protected: // Check the instruction's shape against the shape given by ShapeInference diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc 
b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index 13e4557317f74b3fb46f07fb91c339fd2f34752f..d7458c338e9f1df9fac90270845aae0b8f779ee2 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -27,6 +27,7 @@ using tensorflow::strings::HumanReadableElapsedTime; using tensorflow::strings::HumanReadableNumBytes; using tensorflow::strings::Printf; using tensorflow::strings::StrAppend; +using tensorflow::strings::StrCat; string HumanReadableProfileBuilder::ToString() const { string s; @@ -35,20 +36,26 @@ string HumanReadableProfileBuilder::ToString() const { computation_name_.c_str(), HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)).c_str()); - auto append_op = [&](const OpInfo& op) { + auto print_op = [&](const OpInfo& op) { + // Skip ops with 0 optimal seconds and 0 actual cycles. These are ops that + // were expected to be free and are actually free -- things like (on most + // backends) kParameter or kConstant HLOs. There's no need to clutter the + // profile with these. 
+ if (op.optimal_seconds == 0 && op.cycles == 0) { + return; + } + string bytes_per_sec; string bytes_per_cycle; - if (op.cycles <= 0 || op.bytes_accessed < 0) { - bytes_per_sec = ""; - bytes_per_cycle = ""; - } else { - bytes_per_sec = - HumanReadableNumBytes(op.bytes_accessed / CyclesToSeconds(op.cycles)); + if (op.cycles > 0 && op.bytes_accessed >= 0) { + bytes_per_sec = StrCat( + HumanReadableNumBytes(op.bytes_accessed / CyclesToSeconds(op.cycles)), + "/s"); + double bpc = static_cast(op.bytes_accessed) / op.cycles; if (op.bytes_accessed > op.cycles) { - bytes_per_cycle = HumanReadableNumBytes(op.bytes_accessed / op.cycles); + bytes_per_cycle = StrCat(HumanReadableNumBytes(bpc), "/cycle"); } else { - bytes_per_cycle = - Printf("%.3fB", static_cast(op.bytes_accessed) / op.cycles); + bytes_per_cycle = Printf("%.3fB/cycle", bpc); } } @@ -59,14 +66,16 @@ string HumanReadableProfileBuilder::ToString() const { double nsecs = op.cycles / clock_rate_ghz_; Appendf(&s, - "%15lld cycles (%6.2f%%) :: %12.1f usec (%12.1f optimal) :: %18s " - ":: %18s :: %12s/s :: %12s/cycle :: %s\n", + "%15lld cycles (%6.2f%%) :: %12.1f usec %22s :: %18s " + ":: %18s :: %14s :: %16s :: %s\n", op.cycles, cycles_percent, CyclesToMicroseconds(op.cycles), - op.optimal_seconds * 1e6, + op.optimal_seconds < 0 + ? "" + : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(), op.flop_count <= 0 - ? "" + ? "" : HumanReadableNumFlops(op.flop_count, nsecs).c_str(), - op.transcendental_count <= 0 ? "" + op.transcendental_count <= 0 ? 
"" : HumanReadableNumTranscendentalOps( op.transcendental_count, nsecs) .c_str(), @@ -78,24 +87,26 @@ string HumanReadableProfileBuilder::ToString() const { int64 total_transcendentals = 0.; int64 total_bytes = 0; for (const auto& op : op_infos_) { - optimal_seconds_sum += op.optimal_seconds; - total_flops += op.flop_count; - total_transcendentals += op.transcendental_count; - total_bytes += op.bytes_accessed; + if (op.optimal_seconds > 0) { + optimal_seconds_sum += op.optimal_seconds; + } + total_flops += std::max(op.flop_count, int64{0}); + total_transcendentals += std::max(op.transcendental_count, int64{0}); + total_bytes += std::max(op.bytes_accessed, int64{0}); } VLOG(1) << "Total floating point ops: " << total_flops; - append_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops, - total_transcendentals, total_bytes, optimal_seconds_sum}); + print_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops, + total_transcendentals, total_bytes, optimal_seconds_sum}); - // Sort ops in decreasing order of cycles. + // Sort ops in decreasing order of cycles, and print them. std::vector sorted_ops(op_infos_); std::sort( sorted_ops.begin(), sorted_ops.end(), [](const OpInfo& a, const OpInfo& b) { return a.cycles > b.cycles; }); for (const auto& op : sorted_ops) { - append_op(op); + print_op(op); } if (total_cycles_ <= 0) { @@ -109,8 +120,20 @@ string HumanReadableProfileBuilder::ToString() const { table.SetMetricName("microseconds above estimated optimum"); table.SetEntryName("ops"); table.SetShowCategoryTable(); + table.SetShowAllEntries(); float total_discrepancy_in_microseconds = 0.0f; - for (const auto& op : sorted_ops) { + for (const auto& op : op_infos_) { + // Skip ops with < 0 optimal seconds. These are ops for which we don't + // know the optimal time. + if (op.optimal_seconds < 0) { + continue; + } + // Also skip ops with 0 actual cycles. 
These ops were free; there's no + // need to clutter the "above estimated optimum" table with them, + // because they can't be optimized further. + if (op.cycles == 0) { + continue; + } MetricTableReport::Entry entry; entry.text = op.name; entry.short_text = op.short_name; @@ -128,7 +151,14 @@ string HumanReadableProfileBuilder::ToString() const { table.SetMetricName("microseconds"); table.SetEntryName("ops"); table.SetShowCategoryTable(); - for (const auto& op : sorted_ops) { + table.SetShowAllEntries(); + for (const auto& op : op_infos_) { + // Skip ops with 0 optimal seconds and 0 actual cycles. As in + // print_op(), these are uninteresting because they're expected to be + // free, and they were actually free. + if (op.cycles == 0 && op.optimal_seconds == 0) { + continue; + } MetricTableReport::Entry entry; entry.text = op.name; entry.short_text = op.short_name; @@ -139,6 +169,23 @@ string HumanReadableProfileBuilder::ToString() const { StrAppend(&s, table.MakeReport(CyclesToMicroseconds(total_cycles_))); } } + + if (total_bytes > 0) { + MetricTableReport table; + table.SetMetricName("MiB read+written"); + table.SetEntryName("ops"); + table.SetShowCategoryTable(); + for (const auto& op : op_infos_) { + MetricTableReport::Entry entry; + entry.text = op.name; + entry.short_text = op.short_name; + entry.category_text = op.category; + entry.metric = static_cast(op.bytes_accessed) / (1 << 20); + table.AddEntry(std::move(entry)); + } + StrAppend(&s, + table.MakeReport(static_cast(total_bytes) / (1 << 20))); + } return s; } diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.h b/tensorflow/compiler/xla/service/human_readable_profile_builder.h index fb36d3a0d6532b4157152c49f08f4f247a7c6d89..6f56c3aa82e9d1c942fd67ff7a5948cf2e54370d 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.h +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.h @@ -41,7 +41,8 @@ class HumanReadableProfileBuilder { int64 
total_cycles() const { return total_cycles_; } // Adds an operation to the profile. If you don't know the number of - // floating-point ops or bytes touched by the op, pass -1 for that param. + // floating-point ops or bytes touched by the op, or if you don't know how + // fast it would run optimally, pass -1 for that param. void AddOp(tensorflow::StringPiece op_name, tensorflow::StringPiece short_name, tensorflow::StringPiece category, int64 cycles, int64 flop_count, @@ -62,10 +63,10 @@ class HumanReadableProfileBuilder { string short_name; string category; int64 cycles; - int64 flop_count; + int64 flop_count; // -1 if unknown int64 transcendental_count; - int64 bytes_accessed; - float optimal_seconds; + int64 bytes_accessed; // -1 if unknown + float optimal_seconds; // -1 if unknown }; double CyclesToSeconds(int64 cycles) const { diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc new file mode 100644 index 0000000000000000000000000000000000000000..8b3fa6c1572cf0ed91fc427722edcb23d8b8529d --- /dev/null +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -0,0 +1,733 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/indexed_array_analysis.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { +namespace gtl = ::tensorflow::gtl; + +namespace { +using Analysis = IndexedArrayAnalysis; +using UnknownArray = Analysis::UnknownArray; +using ConstantArray = Analysis::ConstantArray; +using ScalarIndexedArray = Analysis::ScalarIndexedArray; +using tensorflow::gtl::ArraySlice; +using tensorflow::str_util::Join; +} // namespace + +string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { + switch (root->kind()) { + case Array::kUnknown: { + auto* unknown_tensor = root->as(); + return tensorflow::strings::StrCat("%", + unknown_tensor->instruction().name()); + } + + case Array::kConstant: { + if (print_constants) { + string contents = root->as()->literal()->ToString(); + return tensorflow::strings::StrCat( + "(constant ", ShapeUtil::HumanString(root->shape()), " ", contents, + ")"); + } + return tensorflow::strings::StrCat( + "(constant ", ShapeUtil::HumanString(root->shape()), ")"); + } + + case Array::kScalarIndexedConstant: + case Array::kScalarIndexed: { + auto* indexed_array = root->as(); + string name = root->kind() == Array::kScalarIndexedConstant + ? 
"scalar-indexed-const" + : "scalar-indexed"; + return tensorflow::strings::StrCat( + "(", name, " ", ToString(indexed_array->source(), print_constants), + " ", ToString(indexed_array->indices(), print_constants), " ", + indexed_array->source_dim(), "->[", + Join(indexed_array->output_dims(), ","), "])"); + } + } +} + +StatusOr IndexedArrayAnalysis::GetArrayFor( + const HloInstruction* instr) { + auto it = cache_.find(instr); + if (it != cache_.end()) { + return it->second; + } + + TF_RETURN_IF_ERROR(TraverseAndPopulateCache(instr)); + return FindOrDie(cache_, instr); +} + +Status IndexedArrayAnalysis::TraverseAndPopulateCache( + const HloInstruction* root) { + // Depth first search over the DAG, invoking ComputeArrayFor in post order. + // The HLO instructions already in the cache are considered leaves. + + gtl::InlinedVector stack; + + enum DfsState { kDiscovered, kVisited }; + gtl::FlatMap dfs_state_map; + + stack.push_back(root); + InsertOrDie(&dfs_state_map, root, kDiscovered); + + do { + const HloInstruction* instr = stack.back(); + if (cache_.count(instr)) { + stack.pop_back(); + continue; + } + + switch (FindOrDie(dfs_state_map, instr)) { + case kDiscovered: { + for (const HloInstruction* operand : instr->operands()) { + if (!cache_.count(operand)) { + stack.push_back(operand); + CHECK(!dfs_state_map.count(operand) || + dfs_state_map[operand] == kDiscovered); + dfs_state_map[operand] = kDiscovered; + } + } + dfs_state_map[instr] = kVisited; + break; + } + + case kVisited: + stack.pop_back(); + TF_ASSIGN_OR_RETURN(Array * array, ComputeArrayFor(instr)); + InsertOrDie(&cache_, instr, array); + break; + } + } while (!stack.empty()); + + return Status::OK(); +} + +StatusOr IndexedArrayAnalysis::ComputeArrayFor( + const HloInstruction* instr) { + Array* computed_array; + if (instr->IsElementwise() && instr->operand_count() == 1) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForElementwiseUnaryOp( + instr->opcode(), FindOrDie(cache_, 
instr->operand(0)))); + } else if (instr->IsElementwise() && instr->operand_count() == 2) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForElementwiseBinaryOp( + instr->opcode(), FindOrDie(cache_, instr->operand(0)), + FindOrDie(cache_, instr->operand(1)))); + } else if (instr->opcode() == HloOpcode::kConstant) { + TF_ASSIGN_OR_RETURN(computed_array, + ComputeArrayForConstant(instr->literal())); + } else if (instr->opcode() == HloOpcode::kGather) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(), + instr->gather_window_bounds(), + FindOrDie(cache_, instr->operand(0)), + FindOrDie(cache_, instr->operand(1)))); + } else if (instr->opcode() == HloOpcode::kReshape) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForReshape(instr->shape(), + FindOrDie(cache_, instr->operand(0)))); + } else { + computed_array = nullptr; + } + + if (!computed_array) { + computed_array = Construct(instr); + } + + return computed_array; +} + +StatusOr IndexedArrayAnalysis::ComputeArrayForConstant( + const Literal& literal) { + return Construct(&literal); +} + +StatusOr IndexedArrayAnalysis::FoldGatherOfGather( + ScalarIndexedArray* source, Array* indices, int64 source_dim, + tensorflow::gtl::ArraySlice output_dims, Shape shape) { + // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)). + // `source` is the inner Gather(A, X). + + Array* a = source->source(); + Array* x = source->indices(); + Array* y = indices; + + // This bit is slightly tricky, so we do a naive "simulation" of the two + // consecutive gather operations to infer what the composed gather should look + // like. + + enum class IndexComponent { Ungathered, GatheredFirst, GatheredSecond }; + + std::vector simulated_index(a->shape().dimensions_size(), + IndexComponent::Ungathered); + + // Simulate the first gather. 
+ EraseAt(&simulated_index, source->source_dim()); + for (int64 gather_dim : source->output_dims()) { + simulated_index.insert(simulated_index.begin() + gather_dim, + IndexComponent::GatheredFirst); + } + + // Simulate the second gather. + EraseAt(&simulated_index, source_dim); + for (int64 output_dim : output_dims) { + simulated_index.insert(simulated_index.begin() + output_dim, + IndexComponent::GatheredSecond); + } + + int64 source_dim_for_index_array = + FindIndex(source->output_dims(), source_dim); + CHECK_NE(source_dim_for_index_array, source->output_dims().size()); + + std::vector output_dims_for_index_array; + int64 gathered_index_components_seen = 0; + for (IndexComponent simulation_dim : simulated_index) { + if (simulation_dim == IndexComponent::GatheredSecond) { + output_dims_for_index_array.push_back(gathered_index_components_seen); + } + if (simulation_dim != IndexComponent::Ungathered) { + gathered_index_components_seen++; + } + } + + std::vector dim_sizes_for_composed_index; + std::vector output_dims_for_new_gather; + for (int64 i = 0, e = simulated_index.size(); i < e; i++) { + if (simulated_index[i] != IndexComponent::Ungathered) { + dim_sizes_for_composed_index.push_back(shape.dimensions(i)); + output_dims_for_new_gather.push_back(i); + } + } + + Array* inner_indices = ConstructScalarIndexedArray( + x, y, source_dim_for_index_array, output_dims_for_index_array, + ShapeUtil::MakeShape(x->shape().element_type(), + dim_sizes_for_composed_index)); + return ConstructScalarIndexedArray(a, inner_indices, source->source_dim(), + output_dims_for_new_gather, + std::move(shape)); +} + +StatusOr IndexedArrayAnalysis::ComputeArrayForGather( + const Shape& shape, const GatherDimensionNumbers& dim_numbers, + tensorflow::gtl::ArraySlice window_bounds, Array* source, + Array* indices) { + if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { + return nullptr; + } + + CHECK_EQ(dim_numbers.gather_dims_to_operand_dims_size(), 1); + if 
(!c_binary_search(dim_numbers.elided_window_dims(), + dim_numbers.gather_dims_to_operand_dims(0))) { + return nullptr; + } + + int64 source_dim = dim_numbers.gather_dims_to_operand_dims(0); + std::vector output_dims; + for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { + if (!c_binary_search(dim_numbers.output_window_dims(), i)) { + output_dims.push_back(i); + } + } + + if (auto* indexed = dynamic_cast(source)) { + auto it = c_find(indexed->output_dims(), source_dim); + if (it != indexed->output_dims().end()) { + return FoldGatherOfGather(indexed, indices, source_dim, output_dims, + shape); + } + } else if (auto* constant = dynamic_cast(source)) { + return Construct(constant, indices, source_dim, + output_dims, shape); + } + + return Construct(source, indices, source_dim, output_dims, + shape); +} + +namespace { +// Returns an index into `values` such that the product of the range +// [values.begin()+index, values.end()) is equal to `product`. If there is no +// such index, return -1. All integers in `values` must be positive. +int64 FindSuffixWithProduct(ArraySlice values, int64 product) { + DCHECK(c_all_of(values, [](int64 value) { return value > 0; })); + + int64 current_product = 1; + int64 i; + for (i = values.size() - 1; i >= 0 && product > current_product; --i) { + current_product *= values[i]; + } + + if (product == current_product) { + return i + 1; + } + + return -1; +} + +struct ReshapePassthroughDimPair { + int64 result_dim; + int64 operand_dim; +}; + +// Returns a set of dimension pairs such that for all (result_dim, operand_dim) +// in the set: +// +// output_index[result_dim] = SourceIndexOfReshape(output_index)[operand_dim] +// +// The returned vector of pairs is sorted in both the result_dim and the +// operand_dim components.
+std::vector ComputeReshapePassthroughDimPairs( + ArraySlice operand_shape, ArraySlice result_shape) { + // A reshape can be seen as an index mapping from output index to input index: + // + // (i_0, ..., i_n) = f(o_0, ..., o_m) + // + // This function returns the pairs (j, k) for which the following invariant + // holds for all indices in the shape: + // + // o_j == i_k + // + // And this occurs when: + // + // O_{j+1} * ... * O_n == I_{k+1} * ... * I_m + // + // (where O_x are the sizes of the output shape and I_x are the sizes of the + // input shape) and the size of the dimension j of the result is the same as + // the size of dimension k in the operand. + // + // These conditions are sufficient because the Reshape HLO is spec'ed such + // that the rightmost dimensions are always minor in the flattening and refine + // operation. + + std::vector result; + int64 result_subarray_size = 1; + for (int64 result_dim = result_shape.size() - 1; result_dim >= 0; + --result_dim) { + int64 candidate_operand_dim = + FindSuffixWithProduct(operand_shape, result_subarray_size); + + // result_subarray_size does not include the elements in the current + // `result_dim` dimension (we multiply in result_shape[result_dim] at the + // end of loop body) so candidate_operand_dim can never be zero. 
+ CHECK_NE(candidate_operand_dim, 0); + + if (candidate_operand_dim != -1 && + result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) { + result.push_back({/*result_dim=*/result_dim, + /*operand_dim=*/candidate_operand_dim - 1}); + } + result_subarray_size *= result_shape[result_dim]; + } + + c_reverse(result); + + if (VLOG_IS_ON(3)) { + std::vector result_strings; + c_transform(result, std::back_inserter(result_strings), + [](ReshapePassthroughDimPair value) { + return tensorflow::strings::StrCat(value.result_dim, "->", + value.operand_dim); + }); + VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to [" + << Join(result_shape, ",") << "] passthrough indices are [" + << Join(result_strings, ",") << "]"; + } + + DCHECK(c_is_sorted( + result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { + return lhs.result_dim < rhs.result_dim; + })); + + DCHECK(c_is_sorted( + result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { + return lhs.operand_dim < rhs.operand_dim; + })); + + return result; +} + +// Returns true if `dim` is listed as a passthrough operand dim in +// `passthrough_dims`. +bool IsReshapePassthroughOperandDim( + ArraySlice passthrough_dims, int64 dim) { + return c_any_of(passthrough_dims, + [&](ReshapePassthroughDimPair passthrough_dim_pair) { + return passthrough_dim_pair.operand_dim == dim; + }); +} + +// Maps `operand_dim`, which must be a passthrough operand dimension, to its +// corresponding passthrough result dimension based on `passthrough_dims`.
+int64 MapPassthroughOperandDimToResultDim( + ArraySlice passthrough_dims, int64 operand_dim) { + auto it = c_find_if(passthrough_dims, + [&](ReshapePassthroughDimPair passthrough_dim_pair) { + return passthrough_dim_pair.operand_dim == operand_dim; + }); + CHECK(it != passthrough_dims.end()); + return it->result_dim; +} + +int64 FindSourcePositionForPassthroughResultDim(ArraySlice operand_shape, + ArraySlice result_shape, + int64 source_passthrough_dim) { + int64 indexed_source_subarray_size = + std::accumulate(operand_shape.begin() + source_passthrough_dim + 1, + operand_shape.end(), 1, std::multiplies()); + + return FindSuffixWithProduct(result_shape, indexed_source_subarray_size); +} + +}; // namespace + +StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( + const Shape& shape, Array* operand) { + auto* scalar_indexed = dynamic_cast(operand); + if (!scalar_indexed) { + return nullptr; + } + + // Try to fold Reshape(ScalarIndexed(Const, Indices)) + // => ScalarIndexed(Const', Indices) + // + // We can view the reshape and the scalar-indexed operations as functions that + // map an output index (i.e. an index into the result) to an input index + // (i.e. an index into the operand). The key idea used here is that the + // output-to-input mapping for some reshape operations may "pass through" some + // output dimensions into the input space unchanged -- i.e. there may exist + // output dimension "O" and input dimension "I" such that OutputIndex[O] is + // always == InputIndexForReshape(OutputIndex)[I]. 
If these pass-through + // dimensions in the input space of the reshape happen to include all the + // output dimensions for the scalar-indexed node then, roughly, the following + // holds: + // + // SourceIndexOfScalarIndexed(SourceIndexOfReshape(Idx)) + // == SourceIndexOfScalarIndexed(SourceIndexOfReshape(Ps ++ Qs)) + // + // Where Ps is the set of the pass-through components of Idx that are + // also the output dims of the scalar-indexed node, and Qs are the rest. + // For brevity, we're playing fast and loose with the notation here -- we + // don't literally require Idx to be a concatenation of Ps and Qs, as + // suggested by the "++". + // + // == SourceIndexOfScalarIndexed(Ps ++ SourceIndexOfReshape(Qs)) + // + // Again, we're playing fast and loose with the notation around "++". + // Generally this ++ will be a different function than the ++ in the + // previous step. + // + // If the scalar-indexed node has a constant as the source then the + // SourceIndexOfReshape function can be "folded into" the constant itself by + // reshaping it, leaving us with: + // + // == SourceIndexOfScalarIndexed(Ps ++ Qs) + // == SourceIndexOfScalarIndexed(Idx) + // + // which is just a scalar-indexed node (with parameters different from the + // scalar-indexed node we started with) with a reshaped constant as the + // source. + // + // We can't fold SourceIndexOfReshape into the constant without introducing + // another precondition: since the new scalar-indexed node will have a + // reshaped (constant) array as its source it will, in general, have a + // different source dimension than the original scalar-indexed node. This + // source dimension will have to be a passthrough dimension of the + // SourceIndexOfReshape indexing function that is folded into the source. And + // such a dimension need not exist so this is a non-trivial precondition.
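+  //
+  // Illustrative walk-through, using the shapes from the ReshapeOfGather0
+  // test in this change:
+  //
+  //   operand = s32[3,4] constant, indices = s32[5]
+  //   gather  = s32[5,4], i.e.
+  //             (scalar-indexed-const (constant s32[3,4]) %indices 0->[0])
+  //   reshape = s32[5,2,2] reshape(gather)
+  //
+  // The reshape [5,4] -> [5,2,2] passes operand dimension 0 through to result
+  // dimension 0: both have size 5, and the sizes to their right multiply to
+  // the same value (4 == 2*2).  The only output dim of the scalar-indexed
+  // node is 0, so the first precondition holds.  Erasing result dimension 0
+  // from [5,2,2] leaves [2,2].  The source dim (size 3) has 4 elements to its
+  // right in [3,4], which matches the product of all of [2,2], so it is
+  // re-inserted at position 0, giving a source shape of [3,2,2]:
+  //
+  //   (scalar-indexed-const (constant s32[3,2,2]) %indices 0->[0])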
+ + std::vector reshape_passthrough_dims = + ComputeReshapePassthroughDimPairs( + /*operand_shape=*/AsInt64Slice(operand->shape().dimensions()), + /*result_shape=*/AsInt64Slice(shape.dimensions())); + + auto is_reshape_passthrough_operand_dim = [&](int64 operand_dim) { + return IsReshapePassthroughOperandDim(reshape_passthrough_dims, + operand_dim); + }; + + if (!c_all_of(scalar_indexed->output_dims(), + is_reshape_passthrough_operand_dim)) { + return nullptr; + } + + // To compute the shape of the source for the new scalar-indexed node we're + // going to create, we first "undo" the scalar-indexed operation. + std::vector new_scalar_indexed_source_shape(shape.dimensions().begin(), + shape.dimensions().end()); + for (int64 i = scalar_indexed->output_dims().size() - 1; i >= 0; i--) { + int64 output_dim = scalar_indexed->output_dims()[i]; + int64 output_dim_after_reshape = MapPassthroughOperandDimToResultDim( + reshape_passthrough_dims, output_dim); + EraseAt(&new_scalar_indexed_source_shape, output_dim_after_reshape); + } + + // After this, we need to add in the dimension that will be the source + // dimension for the new scalar-indexed node. A scalar-indexed node "removes" + // the source dimensions and "adds" the output dimensions, so to get back to + // the shape for the *source* of the scalar-indexed node we need to remove the + // output dims (which we did above) and then add back the source dim (which we + // are about to do below): + + const Shape& scalar_indexed_source_shape = scalar_indexed->source()->shape(); + + int64 source_dim_for_new_scalar_indexed_node = + FindSourcePositionForPassthroughResultDim( + /*operand_shape=*/AsInt64Slice( + scalar_indexed_source_shape.dimensions()), + /*result_shape=*/new_scalar_indexed_source_shape, + scalar_indexed->source_dim()); + + // We may not be able to find a source dim for the new scalar-indexed node. 
+ // For instance consider: + // + // operand = s32[3,5,2] constant({...}) + // indices = s32[7] parameter(0) + // gather = s32[3,2,7] gather(operand, indices), + // output_window_dims={0,1}, + // elided_window_dims={1}, + // gather_dims_to_operand_dims={1}, + // index_vector_dim=1, + // window_bounds={3,1,2} + // reshape = s32[6,7] reshape(gather) + // + // In this case the gather maps to: + // (scalar-indexed-const (constant s32[3,5,2]) %indices 1->[2]) + // + // and the reshape passes through dimension 2 from its input into dimension 1 + // in its output. However, we can't rewrite the reshape as a scalar-indexed + // node because then we'd have to reshape the [3,5,2] `operand` array to + // [6,5], but then dimension 1 of the reshaped [6,5] array indexes differently + // (a.k.a. isn't pass-through) than the [3,5,2] array. + + if (source_dim_for_new_scalar_indexed_node == -1) { + return nullptr; + } + + InsertAt( + &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node, + scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim())); + + CHECK(IsReshapePassthroughOperandDim( + ComputeReshapePassthroughDimPairs( + /*operand_shape=*/AsInt64Slice( + scalar_indexed_source_shape.dimensions()), + /*result_shape=*/new_scalar_indexed_source_shape), + scalar_indexed->source_dim())); + + auto map_passthrough_operand_dim_to_result_dim = [&](int64 result_dim) { + return MapPassthroughOperandDimToResultDim(reshape_passthrough_dims, + result_dim); + }; + + std::vector output_dims_for_new_scalar_indexed_node; + c_transform(scalar_indexed->output_dims(), + std::back_inserter(output_dims_for_new_scalar_indexed_node), + map_passthrough_operand_dim_to_result_dim); + + TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal, + TakeOwnership(scalar_indexed->literal().Reshape( + new_scalar_indexed_source_shape))); + TF_ASSIGN_OR_RETURN( + Array * new_scalar_indexed_source, + ComputeArrayForConstant(*new_scalar_indexed_source_literal)); + + return 
ConstructScalarIndexedArray( + new_scalar_indexed_source, scalar_indexed->indices(), + source_dim_for_new_scalar_indexed_node, + output_dims_for_new_scalar_indexed_node, shape); +} + +StatusOr +IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, + Array* lhs, + Array* rhs) { + // Try to fold BinaryOp(Broadcast(Const0), ScalarIndexed(Const1, Indices)) + // => ScalarIndexed(BinaryOp(Broadcast'(Const0), Const1), Indices) + // + // We can do this if every output dimension from the scalar-indexed node is a + // broadcasted dimension for the broadcast node. Informally, the precondition + // means Broadcast(Const0)[IDX] is solely a function of the components of IDX + // that are not output-dims for the scalar-indexed node. In other words, for + // every assignment to the non-output dims in IDX we have a "constant" LHS to + // the BinaryOp. This transform propagates this "constant" to the source for + // the scalar-indexed node. + + ScalarIndexedConstantArray* lhs_scalar_indexed_const = + dynamic_cast(lhs); + ScalarIndexedConstantArray* rhs_scalar_indexed_const = + dynamic_cast(rhs); + + bool lhs_is_indexed; + + // One of the operands must be scalar-indexed and the other must be a + // broadcast of a constant. + if (lhs_scalar_indexed_const && !rhs_scalar_indexed_const) { + lhs_is_indexed = true; + } else if (rhs_scalar_indexed_const && !lhs_scalar_indexed_const) { + lhs_is_indexed = false; + } else { + return nullptr; + } + + ScalarIndexedConstantArray* scalar_indexed_const = + lhs_is_indexed ? lhs_scalar_indexed_const : rhs_scalar_indexed_const; + UnknownArray* candidate_broadcast_array = + dynamic_cast(lhs_is_indexed ? 
rhs : lhs); + if (!candidate_broadcast_array || + candidate_broadcast_array->instruction().opcode() != + HloOpcode::kBroadcast) { + return nullptr; + } + + const HloInstruction* broadcast_instr = + &candidate_broadcast_array->instruction(); + const HloInstruction* broadcast_const_operand = broadcast_instr->operand(0); + if (broadcast_const_operand->opcode() != HloOpcode::kConstant) { + return nullptr; + } + + ArraySlice<int64> broadcast_dims = broadcast_instr->dimensions(); + auto is_broadcasted_dim = [&](int64 output_dim) { + return c_find(broadcast_dims, output_dim) == broadcast_dims.end(); + }; + + // All of the output dims must be "broadcasted" dims for the other operand. + if (!c_all_of(scalar_indexed_const->output_dims(), is_broadcasted_dim)) { + return nullptr; + } + + // To figure out the broadcast dimensions for the (constant) source for the + // scalar-indexed node, we "simulate" the index transformation done by the + // existing broadcast: + enum class IndexComponent { Broadcasted, NotBroadcasted }; + std::vector<IndexComponent> simulated_index( + broadcast_instr->shape().dimensions_size(), IndexComponent::Broadcasted); + for (int64 broadcast_dim : broadcast_dims) { + simulated_index[broadcast_dim] = IndexComponent::NotBroadcasted; + } + + // The scalar-indexed node "removes" the source dim and "inserts" the output + // dims. We do the opposite here to undo the scalar-indexed operation. + ArraySlice<int64> output_dims = scalar_indexed_const->output_dims(); + for (int64 i = output_dims.size() - 1; i >= 0; --i) { + CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted); + EraseAt(&simulated_index, output_dims[i]); + } + + InsertAt(&simulated_index, scalar_indexed_const->source_dim(), + IndexComponent::Broadcasted); + + // new_inner_broadcast_dims holds the broadcast dimensions for the inner + // BinaryOp(Broadcast'(Const0), Const1). We now translate simulated_index to + // new_inner_broadcast_dims.
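+  //
+  // As an illustration (hypothetical shapes, not taken from a test): let
+  // Const0 be s32[4], broadcast to s32[5,4] with dimensions={1}, and let the
+  // scalar-indexed operand have a source of shape [3,4], source_dim 0 and
+  // output_dims [0].  Output dim 0 is not in {1}, so the precondition above
+  // holds, and simulated_index evolves as:
+  //
+  //   after marking broadcast_dims:     [Broadcasted, NotBroadcasted]
+  //   after erasing output dim 0:       [NotBroadcasted]
+  //   after inserting at source_dim 0:  [Broadcasted, NotBroadcasted]
+  //
+  // new_inner_broadcast_dims would then be [1], i.e. Const0 is broadcast into
+  // the [3,4] source shape with dimensions={1}.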
+ std::vector new_inner_broadcast_dims; + for (int64 i = 0; i < simulated_index.size(); i++) { + if (simulated_index[i] == IndexComponent::NotBroadcasted) { + new_inner_broadcast_dims.push_back(i); + } + } + + // inner_broadcast_result is the Broadcast'(Const0) bit in + // BinaryOp(Broadcast'(Const0), Const1) + TF_ASSIGN_OR_RETURN( + std::unique_ptr inner_broadcast_result, + broadcast_const_operand->literal().Broadcast( + scalar_indexed_const->source()->shape(), new_inner_broadcast_dims)); + + // literal_for_new_source is BinaryOp(Broadcast'(Const0), Const1) + const Literal* literal_for_new_source; + if (lhs_is_indexed) { + TF_ASSIGN_OR_RETURN( + literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( + opcode, scalar_indexed_const->literal(), *inner_broadcast_result))); + } else { + TF_ASSIGN_OR_RETURN( + literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( + opcode, *inner_broadcast_result, scalar_indexed_const->literal()))); + } + + ConstantArray* new_source = Construct(literal_for_new_source); + return Construct( + new_source, scalar_indexed_const->indices(), + scalar_indexed_const->source_dim(), + std::vector(scalar_indexed_const->output_dims().begin(), + scalar_indexed_const->output_dims().end()), + scalar_indexed_const->shape()); +} + +StatusOr +IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, + Array* operand) { + auto* scalar_indexed_const = + dynamic_cast(operand); + if (scalar_indexed_const == nullptr) { + return nullptr; + } + + // Fold UnaryOp(ScalarIndexed(Const, Indices)) + // => ScalarIndexed(UnaryOp(Const), Indices) + + TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateElementwiseUnaryOp( + opcode, scalar_indexed_const->literal()))); + ConstantArray* new_source = Construct(literal_for_new_source); + return Construct( + new_source, scalar_indexed_const->indices(), + scalar_indexed_const->source_dim(), + 
std::vector(scalar_indexed_const->output_dims().begin(), + scalar_indexed_const->output_dims().end()), + scalar_indexed_const->shape()); +} + +tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const { + return "indexed-array-analysis-printer-pass"; +} + +StatusOr IndexedArrayAnalysisPrinterPass::Run(HloModule* module) { + if (!VLOG_IS_ON(2)) { + return false; + } + + IndexedArrayAnalysis analysis; + for (auto* computation : module->MakeNonfusionComputations()) { + for (auto* instr : computation->instructions()) { + TF_ASSIGN_OR_RETURN(Analysis::Array * t, analysis.GetArrayFor(instr)); + if (!dynamic_cast(t) && !dynamic_cast(t)) { + VLOG(2) << instr->ToString() << " -> " << analysis.ToString(t); + } + } + } + + return false; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h new file mode 100644 index 0000000000000000000000000000000000000000..ce92fd2919c90fa8a2fb7b796ed6f0fdaf48fe62 --- /dev/null +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -0,0 +1,326 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ + +#include <type_traits> + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace xla { + +// IndexedArrayAnalysis decides if an HLO instruction can be rewritten as a +// gather from another array. It does this by mapping HLO instructions to +// instances of IndexedArrayAnalysis::Array, which can be inspected to discover +// whether said HLO is equivalent to a gather. +class IndexedArrayAnalysis { + public: + // IndexedArrayAnalysis maps each HLO instruction to an instance of an Array. + // Array is really just a sum type of the classes that inherit from it. The + // meaning of each of the subtypes is documented on the subtype declaration. + // + // Array instances are immutable once created. + class Array { + public: + enum Kind { kUnknown, kConstant, kScalarIndexedConstant, kScalarIndexed }; + + virtual Kind kind() const = 0; + virtual const Shape& shape() const = 0; + + // Does a checked downcast from `Array` to `T` which must be one of its + // subtypes. + template <typename T> + T* as() { + static_assert((std::is_base_of<Array, T>::value), + "target type not derived from source type"); + // We skip the CHECK and hence the dynamic_cast if RTTI is disabled. +#if !defined(__GNUC__) || defined(__GXX_RTTI) + CHECK_NE(dynamic_cast<T*>(this), nullptr); +#endif // !defined(__GNUC__) || defined(__GXX_RTTI) + + return static_cast<T*>(this); + } + + virtual ~Array() = default; + + Array& operator=(const Array& other) = delete; + }; + + // Represents an HLO instruction that was not analyzable by this + // IndexedArrayAnalysis.
Instances of UnknownArray just wrap an existing + // HloInstruction. + class UnknownArray : public Array { + public: + Kind kind() const override { return kUnknown; } + const Shape& shape() const override { return instruction().shape(); } + const HloInstruction& instruction() const { return instruction_; } + + private: + explicit UnknownArray(const HloInstruction* instr) : instruction_(*instr) {} + + const HloInstruction& instruction_; + + friend class IndexedArrayAnalysis; + }; + + // Represents a constant value. This constant value may be present in the HLO + // module being analyzed, or it could have been created on the fly by the + // analysis. + class ConstantArray : public Array { + public: + Kind kind() const override { return kConstant; } + const Shape& shape() const override { return literal()->shape(); } + const Literal* literal() const { return literal_; } + + private: + explicit ConstantArray(const Literal* literal) : literal_(literal) {} + const Literal* literal_; + + friend class IndexedArrayAnalysis; + }; + + // --------------------------------------------------------------------------- + // Indexed Array Overview + // --------------------------------------------------------------------------- + // + // ScalarIndexedArray and ScalarIndexedConstantArray form the core of this + // analysis. ScalarIndexedConstantArray is just a specialization of + // ScalarIndexedArray so we will only discuss ScalarIndexedArray in this + // overview. + // + // A ScalarIndexedArray represents an array that can be computed by indexing + // into a "source" array using an "indices" tensor. A simple example is a + // gather operation gathering 12 rows out of a [100,100] matrix -- such an + // operation will be represented by an instance of a ScalarIndexedArray with + // the [100,100] matrix as the "source" array and the [12]-shaped indices + // array as the "indices" tensor. 
The ScalarIndexedArray operation itself + // will be of shape [12,100] (assuming we were gathering with axis=0). + // + // Gather operations are not the only operations that map to + // ScalarIndexedArray instances (if that were true there would be little point + // in having a separate analysis). We can often infer ScalarIndexedArrays for + // other operations too. For instance, consider: + // + // %source = f32[100,100] constant + // %indices = s32[12] ... + // %gather = f32[12,100] ... gather from %source using %indices at axis 0 + // %dot = dot(%gather, other_constant) [canonical contracting dims] + // + // The dot operation itself is also a ScalarIndexedArray with source = + // dot(constant, other_constant) and indices = %indices. A reshape of %gather + // to [12,5,20] is also a ScalarIndexedArray with source = an appropriately + // reshaped constant and indices = %indices. + + // Represents the result of a gather operation. This gather operation may + // explicitly be present in the HLO module being analyzed, or it could have + // been created on the fly by the analysis. + // + // An instance of ScalarIndexedArray represents an array whose I'th element can + // be mapped to the J'th element of the `source` array (where I and J are + // multidimensional indices) in this way: + // + // I' = remove components at positions `output_dims` from I + // G' = remove components not at positions `output_dims` from I + // T = indices[G'] + // J = I' with T inserted at position `source_dim` + // + // For example, if source is of shape [11,13,17,19], indices is of shape + // [23,29], output_dims is [0,2] and source_dim is 2 then the output is of + // shape [23,11,29,13,19] and the output index [A,B,C,D,E] is mapped to the + // input index [B,D,indices[A,C],E].
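+  //
+  // To make the mapping concrete with illustrative numbers: for the shapes
+  // above, the output index I = [4,7,5,9,3] gives
+  //
+  //   I' = [7,9,3]     (components of I not at positions 0 and 2)
+  //   G' = [4,5]       (components of I at positions 0 and 2)
+  //   T  = indices[4,5]
+  //   J  = [7,9,T,3]   (T inserted at position 2 of I')
+  //
+  // For the simpler [100,100] gather from the overview, which presumably
+  // corresponds to source_dim = 0 and output_dims = [0], the output index
+  // [r,c] maps to the source index [indices[r],c].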
+ class ScalarIndexedArray : public Array { + public: + Kind kind() const override { return kScalarIndexed; } + const Shape& shape() const override { return shape_; } + + Array* source() const { return source_; } + Array* indices() const { return indices_; } + + // `source_dim` is the dimension in the source array that is being indexed + // over using indices from the `indices` array. See the class documentation + // and the overview for more details. + int64 source_dim() const { return source_dim_; } + + // `output_dims` are the dimensions in the output array that are being used + // to compute an index into the `indices` array. See the class + // documentation and the overview for more details. + tensorflow::gtl::ArraySlice output_dims() const { + return output_dims_; + } + + private: + explicit ScalarIndexedArray(Array* source, Array* indices, int64 source_dim, + std::vector output_dims, Shape shape) + : source_(source), + indices_(indices), + source_dim_(source_dim), + output_dims_(std::move(output_dims)), + shape_(std::move(shape)) {} + + Array* source_; + Array* indices_; + int64 source_dim_; + std::vector output_dims_; + Shape shape_; + + friend class IndexedArrayAnalysis; + }; + + // A ScalarIndexedConstantArray is just a ScalarIndexedArray constrained to + // have a ConstantArray instance as the source. This is an ergonomic + // concession -- in theory it is possible to just keep ScalarIndexedArray and + // check source()->kind(). 
+ class ScalarIndexedConstantArray : public ScalarIndexedArray { + public: + Kind kind() const override { return kScalarIndexedConstant; } + + const Literal& literal() const { + return *source()->as()->literal(); + } + + private: + explicit ScalarIndexedConstantArray(Array* source, Array* indices, + int64 source_dim, + std::vector output_dims, + Shape shape) + : ScalarIndexedArray(source, indices, source_dim, + std::move(output_dims), std::move(shape)) { + CHECK(dynamic_cast(source)); + } + + friend class IndexedArrayAnalysis; + }; + + // Returns an Array instance for `instr`. The IndexedArrayAnalysis instance + // keeps ownership of the returned Array instance. + // + // Caching Behavior: IndexedArrayAnalysis has a cache mapping HLO + // instructions to IndexedArrayAnalysis::Array instances. This entire cache + // becomes stale and may cause the analysis to return incorrect results if any + // transitive operand (stopping at the containing computation) is modified for + // any HLO instruction on which GetArrayFor has been invoked. + // + // NB! By inspecting the implementation, you may be able to infer a stronger + // caching guarantee than what is mentioned above. Nevertheless, what is + // stated above is the contract. + StatusOr GetArrayFor(const HloInstruction* instr); + + // Pretty-prints the expression rooted at `root`. + string ToString(Array* root, bool print_constants = false); + + private: + // Helper function that ensures that every HLO instruction that is + // transitively used by `root` has an entry in `cache_`. + Status TraverseAndPopulateCache(const HloInstruction* root); + + // Creates an Array instance for `instr` under the assumption that all + // operations of `instr` are present in `cache_`. 
+ StatusOr ComputeArrayFor(const HloInstruction* instr); + + StatusOr ComputeArrayForConstant(const Literal& literal); + + StatusOr ComputeArrayForGather( + const Shape& shape, const GatherDimensionNumbers& dim_numbers, + tensorflow::gtl::ArraySlice window_bounds, Array* source, + Array* indices); + + // This tries to fold a ScalarIndexedArray which has another + // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a + // ScalarIndexedArray as indices. If `source` happens to be a + // ScalarIndexedConstantArray this can result in an expression that is more + // canonical. + // + // As an example, consider a gather operation, G0, gathering 7 elements from + // an array "Arr" of shape [100] resulting in an array of shape [7], and a + // second gather operation, G1, which gathers 3 elements out of the result of + // G0 resulting in an array of shape [3]. Let the indices used by G0 be I0 + // (of shape [7]) and the indices used by G1 be I1 (of shape [3]). We can + // instead rewrite G1 to gather directly from "Arr" with the three indices + // from I0 as per I1. In other words, we can rewrite: + // + // G0 = [Arr[i] for i in I0] + // G1 = [G0[i] for i in I1] + // + // into + // + // I2 = [I0[i] for i in I1] + // G1 = [Arr[i] for i in I2] + StatusOr FoldGatherOfGather( + ScalarIndexedArray* source, Array* indices, int64 source_dim, + tensorflow::gtl::ArraySlice output_dims, Shape shape); + + StatusOr ComputeArrayForReshape(const Shape& shape, Array* operand); + + StatusOr ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, + Array* lhs, Array* rhs); + StatusOr ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, + Array* operand); + + template + T* Construct(Args&&...
args) { + T* new_tensor = new T(std::forward(args)...); + owned_tensors_.push_back(std::unique_ptr(new_tensor)); + return new_tensor; + } + + ScalarIndexedArray* ConstructScalarIndexedArray( + Array* source, Array* indices, int64 source_dim, + std::vector output_dims, Shape shape) { + if (source->kind() == Array::kConstant) { + return Construct(source, indices, source_dim, + std::move(output_dims), + std::move(shape)); + } else { + return Construct(source, indices, source_dim, + std::move(output_dims), + std::move(shape)); + } + } + + Literal* TakeOwnership(std::unique_ptr literal) { + owned_literals_.push_back(std::move(literal)); + return owned_literals_.back().get(); + } + + StatusOr TakeOwnership( + StatusOr> literal_or_error) { + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + std::move(literal_or_error)); + owned_literals_.push_back(std::move(literal)); + return owned_literals_.back().get(); + } + + std::vector> owned_tensors_; + std::vector> owned_literals_; + tensorflow::gtl::FlatMap cache_; +}; + +// A pass that prints all non-trivial results returned by IndexedArrayAnalysis. +// This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to +// unconditionally add to the regular HLO pass pipeline. +class IndexedArrayAnalysisPrinterPass : public HloPassInterface { + public: + tensorflow::StringPiece name() const override; + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..373556ebeba883f7dc2116bdf0ffc3274182f775 --- /dev/null +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -0,0 +1,504 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/indexed_array_analysis.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" + +namespace xla { +namespace { +class IndexedArrayAnalysisTest : public HloVerifiedTestBase { + protected: + void AssertArrayForRootExpressionIs(const string& hlo_text, + const string& root_expression) { + AssertArrayForRootExpressionIsImpl(hlo_text, root_expression, + /*print_constants=*/false); + } + + void AssertArrayWithConstantsForRootExpressionIs( + const string& hlo_text, const string& root_expression) { + AssertArrayForRootExpressionIsImpl(hlo_text, root_expression, + /*print_constants=*/true); + } + + private: + void AssertArrayForRootExpressionIsImpl(const string& hlo_text, + const string& root_expression, + bool print_constants) { + IndexedArrayAnalysis indexed_tensor_analysis; + ParseAndVerifyModule(hlo_text); + + TF_ASSERT_OK_AND_ASSIGN( + IndexedArrayAnalysis::Array* const array_result, + indexed_tensor_analysis.GetArrayFor( + module().entry_computation()->root_instruction())); + string string_result = + indexed_tensor_analysis.ToString(array_result, print_constants); + LOG(INFO) << string_result; + ASSERT_EQ(string_result, root_expression); + } +}; + +TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneGather) { + string hlo_text = R"( +HloModule SimpleGather + 
+ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[5] parameter(1) + ROOT gather = s32[5,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, + "(scalar-indexed %operand %indices 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneConstantGather) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + indices = s32[5] parameter(0) + ROOT gather = s32[5,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, "(scalar-indexed-const (constant s32[3,3]) %indices 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + indices_a = s32[5] parameter(0) + indices_b = s32[2] parameter(1) + gather_a = s32[5,3] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} + ROOT gather_b = s32[2,3] gather(gather_a, indices_b), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed-const (constant s32[3,3]) (scalar-indexed %indices_a " + "%indices_b 0->[0]) 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_ManyToOneWithOneToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,2] parameter(0) + indices_a = s32[5,7] parameter(1) + indices_b = s32[2] parameter(2) + gather_a = s32[5,3,7] gather(operand, indices_a), + 
output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3,1} + ROOT gather_b = s32[5,3,2] gather(gather_a, indices_b), + output_window_dims={0,1}, + elided_window_dims={2}, + gather_dims_to_operand_dims={2}, + index_vector_dim=1, + window_bounds={5,3,1} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, + "(scalar-indexed %operand (scalar-indexed " + "%indices_a %indices_b 1->[1]) 1->[0,2])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOneWithManyToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,6] parameter(0) + indices_a = s32[2] parameter(1) + indices_b = s32[5,7] parameter(2) + gather_a = s32[2,6] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,6} + ROOT gather_b = s32[5,6,7] gather(gather_a, indices_b), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,6} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, + "(scalar-indexed %operand (scalar-indexed " + "%indices_a %indices_b 0->[0,1]) 0->[0,2])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_ManyToOneWithManyToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,2] parameter(0) + indices_a = s32[5,7] parameter(1) + indices_b = s32[4,8] parameter(2) + gather_a = s32[5,3,7] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3,1} + ROOT gather_b = s32[4,5,3,8] gather(gather_a, indices_b), + output_window_dims={1,2}, + elided_window_dims={2}, + gather_dims_to_operand_dims={2}, + index_vector_dim=2, + window_bounds={5,3,1} +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed %operand (scalar-indexed %indices_a %indices_b " + "1->[0,2]) 
1->[0,1,3])"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather0) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + indices = s32[5] parameter(0) + gather = s32[5,4] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT reshape = s32[5,2,2] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, "(scalar-indexed-const (constant s32[3,2,2]) %indices 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather1) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + indices = s32[5,7] parameter(0) + gather = s32[5,4,7] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,4} + ROOT reshape = s32[5,2,2,7] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed-const (constant s32[3,2,2]) %indices 0->[0,3])"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather2) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,2,6] constant(s32[3,2,6]{ + {{1,2,3,4,5,6},{1,2,3,4,5,6}}, + {{1,2,3,4,5,6},{1,2,3,4,5,6}}, + {{1,2,3,4,5,6},{1,2,3,4,5,6}}}) + indices = s32[5,7] parameter(0) + gather = s32[5,2,6,7] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,2,6} + ROOT reshape = s32[5,3,4,7] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed-const (constant s32[3,3,4]) %indices 0->[0,3])"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNegative0) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,4] 
constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + indices = s32[5,6] parameter(0) + gather = s32[5,4,6] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,4} + ROOT reshape = s32[5,2,2,2,3] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%reshape"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNegative1) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,5,2] constant(s32[3,5,2]{ + {{1,2},{3,4},{5,6},{7,8},{9,10}}, + {{1,2},{3,4},{5,6},{7,8},{9,10}}, + {{1,2},{3,4},{5,6},{7,8},{9,10}}}) + indices = s32[7] parameter(0) + gather = s32[3,2,7] gather(operand, indices), + output_window_dims={0,1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3,1,2} + ROOT reshape = s32[6,7] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%reshape"); +} + +TEST_F(IndexedArrayAnalysisTest, UnaryOpOfGather) { + string hlo_text = R"( +HloModule UnaryOpOfGather + +ENTRY main { + operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + indices = s32[5] parameter(0) + gather = f32[5,4] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT tanh = f32[5,4] tanh(gather) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant f32[3,4] f32[3,4] { + { 0.761594176, 0.964027584, 0.995054781, 0.999329329 }, + { 0.761594176, 0.995054781, 0.964027584, 0.999329329 }, + { 0.999329329, 0.995054781, 0.964027584, 0.761594176 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, AddBroadcastedScalarWithGather) { + string hlo_text = R"( +HloModule AddBroadcastedScalarWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant = 
s32[] constant(5) + constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT add = s32[5,4] add(gather, constant_broadcasted) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant s32[3,4] s32[3,4] { + { 6, 7, 8, 9 }, + { 6, 8, 7, 9 }, + { 9, 8, 7, 6 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, + SubtractBroadcastedScalarWithGather_GatherIsLhs) { + string hlo_text = R"( +HloModule SubtractBroadcastedScalarWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant = s32[] constant(5) + constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT sub = s32[5,4] subtract(gather, constant_broadcasted) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant s32[3,4] s32[3,4] { + { -4, -3, -2, -1 }, + { -4, -2, -3, -1 }, + { -1, -2, -3, -4 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, + SubtractBroadcastedScalarWithGather_GatherIsRhs) { + string hlo_text = R"( +HloModule SubtractBroadcastedScalarWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant = s32[] constant(5) + constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT sub 
= s32[5,4] subtract(constant_broadcasted, gather) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant s32[3,4] s32[3,4] { + { 4, 3, 2, 1 }, + { 4, 2, 3, 1 }, + { 1, 2, 3, 4 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather) { + string hlo_text = R"( +HloModule AddBroadcastedVectorWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant_vect = s32[4] constant({10,11,12,13}) + constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT add = s32[5,4] add(gather, constant_broadcasted) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant s32[3,4] s32[3,4] { + { 11, 13, 15, 17 }, + { 11, 14, 14, 17 }, + { 14, 14, 14, 14 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather_Negative) { + string hlo_text = R"( +HloModule AddBroadcastedVectorWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant_vect = s32[5] constant({10,11,12,13,14}) + constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT add = s32[5,4] add(gather, constant_broadcasted) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%add"); +} + +TEST_F(IndexedArrayAnalysisTest, RegularUnaryOp) { + string hlo_text = R"( +HloModule RegularUnaryOp + +ENTRY main { + input = f32[100] parameter(0) + ROOT tanh = f32[100] tanh(input) +} 
+)"; + + AssertArrayForRootExpressionIs(hlo_text, "%tanh"); +} + +TEST_F(IndexedArrayAnalysisTest, RegularBinaryOp) { + string hlo_text = R"( +HloModule RegularUnaryOp + +ENTRY main { + input0 = f32[100] parameter(0) + input1 = f32[100] parameter(1) + ROOT add = f32[100] add(input0, input1) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%add"); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 7aa1c7c8358318d02a000d968a2672123400ad6e..d2af261008f40ee83e0676cfc7e67c45f8be1844 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -71,7 +71,7 @@ TEST_F(InlinerTest, MapMax) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto expected = Literal::CreateR1({4, 3, 3, 4}); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } // Test that `constant` function is changed to `broadcast`. @@ -105,7 +105,7 @@ TEST_F(InlinerTest, MapConstant) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto expected = Literal::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } TEST_F(InlinerTest, MapSubtractOppositeOrder) { @@ -143,7 +143,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { // Verify execution on CPU. 
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto expected = Literal::CreateR1({3, 1, -1, -3}); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 6bb2ca19fe235d61dfad2c7cde2f31c797628c1d..abedb4063d3763516e66cff36633dbd90c8cafde 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -96,6 +96,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kShiftRightLogical: case HloOpcode::kSlice: case HloOpcode::kSubtract: + case HloOpcode::kGenerateToken: case HloOpcode::kTranspose: case HloOpcode::kTuple: return false; @@ -118,13 +119,16 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: case HloOpcode::kDivide: + case HloOpcode::kDomain: case HloOpcode::kDot: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFft: case HloOpcode::kFusion: case HloOpcode::kGather: case HloOpcode::kHostCompute: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kMap: case HloOpcode::kParameter: case HloOpcode::kPower: @@ -176,8 +180,7 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) { bool InstructionFusion::CanFuseOnAllPaths( HloInstruction* producer, HloInstruction* consumer, - const HloReachabilityMap& reachability_map, - const DoNotFuseSet& do_not_fuse) { + const HloInstructionSet& do_not_duplicate) { if (consumer == producer) { return true; } @@ -188,10 +191,11 @@ bool InstructionFusion::CanFuseOnAllPaths( auto* consumer_operand = consumer->mutable_operand(i); // If the operand is not on a path to the producer, it doesn't matter // whether it's fusable. 
- if (!reachability_map.IsReachable(producer, consumer_operand)) { + if (!reachability_->IsReachable(producer, consumer_operand)) { continue; } - if (do_not_fuse.count(consumer_operand) > 0 || !ShouldFuse(consumer, i)) { + if (do_not_duplicate.count(consumer_operand) > 0 || + !ShouldFuse(consumer, i)) { return false; } // The producer is reachable from consumer_operand which means we need @@ -199,18 +203,16 @@ bool InstructionFusion::CanFuseOnAllPaths( // producer to be fusable into consumer on all paths. // Perform the recursive step: make sure producer can be fused into // consumer_operand on all paths. - if (!CanFuseOnAllPaths(producer, consumer_operand, reachability_map, - do_not_fuse)) { + if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_duplicate)) { return false; } } return true; } -InstructionFusion::DoNotFuseSet InstructionFusion::ComputeGloballyUnfusable( +InstructionFusion::HloInstructionSet +InstructionFusion::ComputeGloballyUnfusable( tensorflow::gtl::ArraySlice post_order) { - auto reachability = computation_->ComputeReachability(); - // Forbid fusion of producers that: // a) Need to be duplicated, unless they can be fused into all consumers // via all paths. @@ -220,10 +222,10 @@ InstructionFusion::DoNotFuseSet InstructionFusion::ComputeGloballyUnfusable( // Note that if we allow fusion by these global rules, we may still forbid // fusing operations that require duplication later depending on // is_expensive_(). - DoNotFuseSet do_not_fuse; + HloInstructionSet do_not_duplicate; for (HloInstruction* consumer : post_order) { for (HloInstruction* producer : consumer->operands()) { - if (do_not_fuse.count(producer) > 0) { + if (do_not_duplicate.count(producer) > 0) { continue; } @@ -252,14 +254,14 @@ InstructionFusion::DoNotFuseSet InstructionFusion::ComputeGloballyUnfusable( // A will be not allowed to be fused into B, as it cannot be fused via // all paths. 
if (producer->IsFusable() && - CanFuseOnAllPaths(producer, consumer, *reachability, do_not_fuse)) { + CanFuseOnAllPaths(producer, consumer, do_not_duplicate)) { continue; } - do_not_fuse.insert(producer); + do_not_duplicate.insert(producer); } } - return do_not_fuse; + return do_not_duplicate; } StatusOr InstructionFusion::Run(HloModule* module) { @@ -271,6 +273,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { for (auto* computation : module->MakeNonfusionComputations()) { CHECK(!computation->IsFusionComputation()); computation_ = computation; + reachability_ = computation_->ComputeReachability(); // We want to be able to remove arbitrary instructions from the post order // and also compare positions of instructions in the post order. To make @@ -288,7 +291,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { InsertOrDie(&post_order_index, post_order[i], i); } - DoNotFuseSet do_not_fuse = ComputeGloballyUnfusable(post_order); + HloInstructionSet do_not_duplicate = ComputeGloballyUnfusable(post_order); // Instruction fusion effectively fuses edges in the computation graph // (producer instruction -> consumer instruction) so we iterate over all @@ -356,9 +359,20 @@ StatusOr InstructionFusion::Run(HloModule* module) { // ensures that B will be considered before A. // // We store the original indices of the operands to pass to ShouldFuse. - std::vector sorted_operand_numbers(instruction->operands().size()); - std::iota(std::begin(sorted_operand_numbers), - std::end(sorted_operand_numbers), 0); + std::vector sorted_operand_numbers; + sorted_operand_numbers.reserve(instruction->operands().size()); + for (int i = 0; i < instruction->operands().size(); ++i) { + // This will happen if we have two possible instructions to fuse the + // same operand into; once the operand is fused into one instruction, + // the other instruction will get a new get-tuple-element as its + // operand, which is not in the post-order index. 
+ // TODO(tjoerg): Look into fusing past these multi-output fuse points. + if (post_order_index.find(instruction->mutable_operand(i)) == + post_order_index.end()) { + continue; + } + sorted_operand_numbers.push_back(i); + } std::sort( sorted_operand_numbers.begin(), sorted_operand_numbers.end(), [&](int64 i, int64 j) { @@ -375,13 +389,20 @@ StatusOr InstructionFusion::Run(HloModule* module) { if (!operand->IsFusable()) { continue; } - if (!ShouldFuse(instruction, i)) { - continue; - } - if (do_not_fuse.count(operand) > 0) { + + HloInstruction* fusion_instruction; + // Try "regular" fusion if the operand may be duplicated. Otherwise, + // perform multi-output fusion, unless this creates a cycle. + // TODO(tjoerg): Consider making multi-output fusion the default. + if (ShouldFuse(instruction, i) && + do_not_duplicate.count(operand) == 0) { + fusion_instruction = Fuse(operand, instruction); + } else if (ShouldFuseIntoMultiOutput(instruction, i) && + !MultiOutputFusionCreatesCycle(operand, instruction)) { + fusion_instruction = FuseIntoMultiOutput(operand, instruction); + } else { continue; } - HloInstruction* fusion_instruction = Fuse(operand, instruction); // Fusing an instruction into a fusion instruction can change the // operand set of the fusion instruction. 
For simplicity just push the @@ -412,12 +433,9 @@ StatusOr InstructionFusion::Run(HloModule* module) { return changed; } -HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, - HloInstruction* consumer) { +HloInstruction* InstructionFusion::AddFusionInstruction( + HloInstruction* producer, HloInstruction* consumer) { HloInstruction* fusion_instruction; - - VLOG(2) << "Fusing " << producer->ToString() << " into " - << consumer->ToString(); auto kind = ChooseKind(producer, consumer); if (consumer->opcode() == HloOpcode::kFusion) { fusion_instruction = consumer; @@ -429,11 +447,40 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, HloInstruction::CreateFusion(consumer->shape(), kind, consumer)); TF_CHECK_OK(computation_->ReplaceInstruction(consumer, fusion_instruction)); } + return fusion_instruction; +} +HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, + HloInstruction* consumer) { + VLOG(2) << "Fusing " << producer->ToString() << " into " + << consumer->ToString(); + HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer); fusion_instruction->FuseInstruction(producer); return fusion_instruction; } +HloInstruction* InstructionFusion::FuseIntoMultiOutput( + HloInstruction* producer, HloInstruction* consumer) { + VLOG(2) << "Multi-output fusing " << producer->ToString() << " into " + << consumer->ToString(); + HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer); + fusion_instruction->FuseInstructionIntoMultiOutput(producer); + return fusion_instruction; +} + +bool InstructionFusion::MultiOutputFusionCreatesCycle( + HloInstruction* producer, HloInstruction* consumer) { + return c_any_of( + consumer->operands(), [&](const HloInstruction* consumer_operand) { + // The fusion algorithm traverses the HLO graph in reverse post order. + // Thus `consumer` is visited before its operands (including + // `producer`). Therefore, consumer operands cannot have been fused yet. 
+ // It is thus safe to use the pre-computed reachability map. + return consumer_operand != producer && + reachability_->IsReachable(producer, consumer_operand); + }); +} + bool InstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index 2ea1fcf937ceaf2cce3f8ed0891399384d93dbd0..f73ca9adf768ed26f9ec9f162e01b7b160f50daf 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -61,6 +61,14 @@ class InstructionFusion : public HloPassInterface { // Subtypes can override this with target-specific heuristics. virtual bool ShouldFuse(HloInstruction* consumer, int64 operand_index); + // Returns whether multi-output fusion can be applied to fuse `producer` into + // `consumer`. In contrast to "regular" fusion, the `producer` is not + // duplicated by multi-output fusion. + virtual bool ShouldFuseIntoMultiOutput(HloInstruction* consumer, + int64 operand_index) { + return false; + } + // Chooses a fusion kind for `producer` and `consumer`. // Default method chooses `kLoop`. virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer, @@ -70,6 +78,13 @@ class InstructionFusion : public HloPassInterface { virtual HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer); + // Creates a new fusion instruction containing `producer` and `consumer`. A + // tuple is added as the fusion instruction's root, which consumes from both, + // `producer` and `consumer`. This style of fusion is referred to as + // multi-output fusion. + virtual HloInstruction* FuseIntoMultiOutput(HloInstruction* producer, + HloInstruction* consumer); + // An "effectively unary" operation is one that has at most one "large" // input with the others being negligible in terms of memory usage. 
// We use "has a smaller true rank than the output" as a heuristic @@ -90,26 +105,34 @@ class InstructionFusion : public HloPassInterface { // Current HloComputation instance the loop fuser is traversing. HloComputation* computation_; HloModule* module_; + // Reachability information for the current computation. + std::unique_ptr<HloReachabilityMap> reachability_; private: // The set of producers whose consumers we cannot fuse into. - using DoNotFuseSet = std::unordered_set<HloInstruction*>; + using HloInstructionSet = std::unordered_set<HloInstruction*>; + + HloInstruction* AddFusionInstruction(HloInstruction* producer, + HloInstruction* consumer); // Whether or not we can fuse producer into consumer on all paths // from the producer to the consumer where nodes are HLOs and edges are uses. bool CanFuseOnAllPaths(HloInstruction* producer, HloInstruction* consumer, - const HloReachabilityMap& reachability_map, - const DoNotFuseSet& do_not_fuse); + const HloInstructionSet& do_not_fuse); // Computes the set of nodes that we do not want to fuse into any of their // consumers based on a global analysis of the HLO graph. - DoNotFuseSet ComputeGloballyUnfusable( + HloInstructionSet ComputeGloballyUnfusable( tensorflow::gtl::ArraySlice<HloInstruction*> post_order); // Used to determine if an HLO is expensive. Expensive operations will not be // duplicated. std::function<bool(const HloInstruction& instruction)> is_expensive_; + // Whether multi-output fusion would introduce a cycle into the HLO graph. + bool MultiOutputFusionCreatesCycle(HloInstruction* producer, + HloInstruction* consumer); + // Returns whether we may duplicate an instruction if we want to fuse it. 
bool may_duplicate_; diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 6dd8fa1ab08737f0d77c47a1f8ed59a85b4f2bbd..21db2338995960bde00ec9c4b325e5562fc3a592 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { @@ -25,6 +25,91 @@ namespace op = xla::testing::opcode_matchers; using InstructionFusionTest = HloTestBase; +// Subclass of InstructionFusion exposing the protected methods Fuse and +// FuseIntoMultiOutput for testing. +class InstructionFusionForTesting : public InstructionFusion { + public: + explicit InstructionFusionForTesting(HloModule* module) + : InstructionFusion(InstructionFusion::IsExpensive) { + module_ = module; + computation_ = module->entry_computation(); + } + + HloInstruction* Fuse(HloInstruction* producer, + HloInstruction* consumer) override { + return InstructionFusion::Fuse(producer, consumer); + } + + HloInstruction* FuseIntoMultiOutput(HloInstruction* producer, + HloInstruction* consumer) override { + return InstructionFusion::FuseIntoMultiOutput(producer, consumer); + } +}; + +TEST_F(InstructionFusionTest, FuseInstructions) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY entry_computation { + p0 = f32[4,3]{1,0} parameter(0) + add = f32[4,3]{1,0} add(p0, p0) + ROOT sub = f32[4,3]{1,0} subtract(add, p0) + })") + .ValueOrDie(); + HloInstruction* sub = module->entry_computation()->root_instruction(); + HloInstruction* add = sub->mutable_operand(0); + HloInstruction* fusion = + 
InstructionFusionForTesting(module.get()).Fuse(add, sub); + + ASSERT_THAT(fusion, op::Fusion()) << module->ToString(); + EXPECT_THAT(fusion->fused_expression_root(), + op::Subtract(op::Add(), op::Parameter())) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) { + auto module = ParseHloString(R"( + HloModule test_module + fused_computation { + p1 = f32[4,3] parameter(0) + add = f32[4,3] add(p1, p1) + } + ENTRY entry_computation { + p0 = f32[4,3] parameter(0) + abs = f32[4,3] abs(p0) + ROOT fusion = f32[4,3] fusion(abs), kind=kLoop, calls=fused_computation + })") + .ValueOrDie(); + HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* abs = root->mutable_operand(0); + HloInstruction* fusion = + InstructionFusionForTesting(module.get()).Fuse(abs, root); + + ASSERT_THAT(fusion, op::Fusion()) << module->ToString(); + EXPECT_THAT(fusion->fused_expression_root(), op::Add(op::Abs(), op::Abs())) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, FuseInstructionsIntoMultiOutput) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY entry_computation { + p0 = f32[4,3]{1,0} parameter(0) + abs = f32[4,3]{1,0} abs(p0) + tanh = f32[4,3]{1,0} tanh(abs) + ROOT add = f32[4,3]{1,0} add(abs, tanh) + })") + .ValueOrDie(); + HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* abs = root->mutable_operand(0); + HloInstruction* tanh = root->mutable_operand(1); + HloInstruction* fusion = + InstructionFusionForTesting(module.get()).FuseIntoMultiOutput(abs, tanh); + + ASSERT_THAT(fusion, op::Fusion()) << module->ToString(); + EXPECT_THAT(fusion->fused_expression_root(), op::Tuple(op::Tanh(), op::Abs())) + << module->ToString(); +} + TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction( @@ -92,7 +177,8 @@ TEST_F(InstructionFusionTest, 
AvoidDuplicationIfNotAllFusable) { EXPECT_FALSE( InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) .Run(module.get()) - .ValueOrDie()); + .ValueOrDie()) + << module->ToString(); } // Counts the number of HLO ops with a given op code in the specified module. @@ -109,7 +195,7 @@ static int Count(const HloModule& module, HloOpcode op) { } TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -134,7 +220,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { // // p0 -> add -------------------------> sub // \-> abs1 -> rng -> abs2 -/ - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -151,7 +237,11 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { .Run(module.get()) .ValueOrDie()) << module->ToString(); - EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Subtract(op::Abs(op::Parameter()), op::Parameter())) + << module->ToString(); // Make sure the add hasn't been duplicated. 
EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString(); @@ -161,7 +251,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { // p0 -> add -------------------------> sub // \-> abs1 -> log -> abs2 -/ // \-> send - module = tools::Parse(R"( + module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -192,7 +282,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { // \ \-> add2 -/ // \-> log -/ // \-> send - module = tools::Parse(R"( + module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -224,7 +314,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { // \------> sub1 // log -/ // \-> send - module = tools::Parse(R"( + module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -244,7 +334,12 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { .Run(module.get()) .ValueOrDie()) << module->ToString(); - EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString(); + root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Tuple(op::Subtract(op::Parameter(), op::Parameter()), + op::Subtract(op::Parameter(), op::Parameter()))) + << module->ToString(); // Make sure we didn't duplicate any adds. 
EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString(); @@ -295,7 +390,7 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { TEST_F(InstructionFusionTest, WideningConvertsAreAlwaysDuplicableIntoConsumers) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY Test { p0 = f16[100] parameter(0) diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 45505484951abfcee93a62fec7a99e86cbb9150c..524d3234eb4eff9c7d000eca1a0d9f5c4fae90af 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -18,7 +18,6 @@ cc_library( "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service/interpreter:platform_id", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -117,6 +116,5 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_headers_lib", - "//tensorflow/core:stream_executor_no_cuda", ], ) diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index eecbbcb93df64b09acb5e009d3db79e51dab0c93..c1666530687f2f8407a9dcb4e271c9d95552a689 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -34,7 +34,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/types.h" @@ -45,8 +44,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); pipeline.AddPass( - hlo_module->device_entry_computation_layout()); - + hlo_module->mutable_device_entry_computation_layout()); return pipeline.Run(hlo_module).status(); } @@ -71,7 +69,8 @@ StatusOr> InterpreterCompiler::RunBackend( // Create executable from only the Hlo module. std::unique_ptr executable = - xla::MakeUnique(std::move(hlo_module)); + xla::MakeUnique(std::move(hlo_module), + xla::MakeUnique()); return std::move(executable); } @@ -101,17 +100,14 @@ HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction() return InterpreterExecutable::ShapeSizeBytes; } -static std::unique_ptr CreateComputationPlacer() { - return xla::MakeUnique(); -} - static bool InitModule() { xla::Compiler::RegisterCompilerFactory( se::interpreter::kXlaInterpreterPlatformId, []() { return xla::MakeUnique(); }); xla::ComputationPlacer::RegisterComputationPlacer( - se::interpreter::kXlaInterpreterPlatformId, &CreateComputationPlacer); + se::interpreter::kXlaInterpreterPlatformId, + []() { return xla::MakeUnique(); }); return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 61f199bc9e8f4f95a2f097af4abf9395a1e05f64..029e71058a7373b9310c6d9ffdb65f72ca28e5af 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -24,7 +24,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -32,16 +31,17 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace interpreter { InterpreterExecutable::InterpreterExecutable( - std::unique_ptr hlo_module) + std::unique_ptr hlo_module, + std::unique_ptr evaluator) : Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr, - /*hlo_profile_index_map=*/nullptr) {} + /*hlo_profile_index_map=*/nullptr), + evaluator_(std::move(evaluator)) {} InterpreterExecutable::~InterpreterExecutable() {} @@ -82,10 +82,13 @@ StatusOr InterpreterExecutable::ExecuteOnStream( } // Execute the graph using the HloEvaluator. - HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN( - std::unique_ptr result_literal, - evaluator.Evaluate>(*computation, arg_literals)); + std::unique_ptr result_literal; + { + tensorflow::mutex_lock lock(evaluator_lock_); + TF_ASSIGN_OR_RETURN(result_literal, + evaluator_->Evaluate>( + *computation, arg_literals)); + } // Transform the result literal back into a ShapedBuffer. 
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index b0b797ca7d6f449a11c662ffba7c2a0a0040e47e..91d8148d26dc8eddbafdaf4870d9efbb73a12816 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" @@ -30,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -40,13 +42,15 @@ namespace interpreter { // buffer allocation. Refer to interpreter/README.md for more. 
class InterpreterExecutable : public Executable { public: - InterpreterExecutable(std::unique_ptr hlo_module); + InterpreterExecutable(std::unique_ptr hlo_module, + std::unique_ptr evaluator); ~InterpreterExecutable() override; StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) override; + HloExecutionProfile* hlo_execution_profile) override + LOCKS_EXCLUDED(evaluator_lock_); StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, @@ -54,6 +58,11 @@ class InterpreterExecutable : public Executable { static int64 ShapeSizeBytes(const Shape& shape); + protected: + // The interpreter interprets executables with an HloEvaluator. + std::unique_ptr evaluator_ PT_GUARDED_BY(evaluator_lock_); + mutable tensorflow::mutex evaluator_lock_; + private: TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable); }; diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc index 92e069a8c67c1d441ba9d396dee503c9b3bde0df..42c2c28997d5f3b02f1fe4effca164c893e4071d 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/interpreter/executor.h" -#include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/stream_executor/device_options.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/ptr_util.h" @@ -31,13 +30,13 @@ limitations under the License. 
namespace stream_executor { namespace interpreter { -XlaInterpreterPlatform::XlaInterpreterPlatform() : name_("Interpreter") {} +XlaInterpreterPlatform::XlaInterpreterPlatform(const string& name, + const Platform::Id& id) + : name_(name), id_(id) {} XlaInterpreterPlatform::~XlaInterpreterPlatform() {} -Platform::Id XlaInterpreterPlatform::id() const { - return kXlaInterpreterPlatformId; -} +Platform::Id XlaInterpreterPlatform::id() const { return id_; } int XlaInterpreterPlatform::VisibleDeviceCount() const { return 1; } @@ -106,8 +105,6 @@ REGISTER_MODULE_INITIALIZER( interpreter_platform, stream_executor::interpreter::InitializeXlaInterpreterPlatform()); -DECLARE_MODULE_INITIALIZER(multi_platform_manager); - // Note that module initialization sequencing is not supported in the // open-source project, so this will be a no-op there. REGISTER_MODULE_INITIALIZER_SEQUENCE(interpreter_platform, diff --git a/tensorflow/compiler/xla/service/interpreter/platform.h b/tensorflow/compiler/xla/service/interpreter/platform.h index d68c5aa20dda7ac246ed4aa667851e385a604c04..0187f6d473b19f50136e214708e56f833627d9d1 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.h +++ b/tensorflow/compiler/xla/service/interpreter/platform.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/stream_executor/executor_cache.h" #include "tensorflow/stream_executor/plugin.h" #include "tensorflow/stream_executor/stream_executor.h" @@ -28,7 +29,8 @@ namespace interpreter { class XlaInterpreterPlatform : public Platform { public: - XlaInterpreterPlatform(); + XlaInterpreterPlatform(const string& name = "Interpreter", + const Platform::Id& id = kXlaInterpreterPlatformId); ~XlaInterpreterPlatform() override; Platform::Id id() const override; @@ -55,6 +57,8 @@ class XlaInterpreterPlatform : public Platform { private: // This platform's name. string name_; + // This platform's id. 
+ Platform::Id id_; // Cache of created StreamExecutors. ExecutorCache executor_cache_; diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index cfa7ba5e81ddd003978a2bd763384581c55b5c83..7067b6f86a0fb24fb946ad236bca9bbd48d53722 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -31,10 +31,12 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -400,9 +402,9 @@ string LayoutConstraints::ToString() const { } Status LayoutAssignment::AddMandatoryConstraints( - const ComputationLayout& computation_layout, - const ChannelLayoutConstraints* channel_constraints, - HloComputation* computation, LayoutConstraints* constraints) { + const ComputationLayout* computation_layout, + ChannelLayoutConstraints* channel_constraints, HloComputation* computation, + LayoutConstraints* constraints) { VLOG(3) << "Adding mandatory layout constraints to computation " << computation->name(); @@ -424,11 +426,16 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR(constraints->SetOperandLayout( instruction->outfeed_shape(), instruction, 0)); } else if (instruction->opcode() == HloOpcode::kParameter) { - // Parameter layouts must match the respective layout in - // ComputationLayout. 
- shape_with_layout = - &computation_layout.parameter_layout(instruction->parameter_number()) - .shape(); + if (computation_layout != nullptr) { + const ShapeLayout& parameter_layout = + computation_layout->parameter_layout( + instruction->parameter_number()); + if (parameter_layout.LayoutIsSet()) { + // Parameter layouts must match the respective layout in + // ComputationLayout, if there is one. + shape_with_layout = &parameter_layout.shape(); + } + } } if (shape_with_layout != nullptr) { TF_RETURN_IF_ERROR( @@ -493,9 +500,8 @@ Status LayoutAssignment::AddMandatoryConstraints( HloComputation* body = instruction->while_body(); HloComputation* condition = instruction->while_condition(); const HloInstruction* init = instruction->operand(0); - const ComputationLayout& body_layout = - FindOrDie(computation_layouts_, body); - const ComputationLayout& condition_layout = + ComputationLayout& body_layout = FindOrDie(computation_layouts_, body); + ComputationLayout& condition_layout = FindOrDie(computation_layouts_, condition); // Check a few invariants irrespective of layout. @@ -508,26 +514,19 @@ Status LayoutAssignment::AddMandatoryConstraints( condition_layout.parameter_shape(0))); DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape())); - // Return error if earlier layout assignment of the embedded computations - // has produced conflicting layouts.
- if (!ShapeUtil::Equal(body_layout.result_shape(), - body_layout.parameter_shape(0))) { - return InternalError( - "Parameter and result of body computation %s of while instruction " - "%s have different layouts: %s vs %s", - body->name().c_str(), instruction->name().c_str(), - ShapeUtil::HumanString(body_layout.result_shape()).c_str(), - ShapeUtil::HumanString(body_layout.parameter_shape(0)).c_str()); + if (body_layout.result_layout() != body_layout.parameter_layout(0)) { + VLOG(2) << "Reset %while body parameter layout: body=" << body->name() + << " while=" << instruction->name() + << " shape=" << body_layout.result_layout().ToString(); + *body_layout.mutable_parameter_layout(0) = body_layout.result_layout(); } - if (!ShapeUtil::Equal(body->root_instruction()->shape(), - condition->parameter_instruction(0)->shape())) { - return InternalError( - "Parameter of condition computation %s of while instruction " - "%s does not match body computation %s result: %s vs %s", - condition->name().c_str(), instruction->name().c_str(), - body->name().c_str(), - ShapeUtil::HumanString(condition_layout.parameter_shape(0)).c_str(), - ShapeUtil::HumanString(body_layout.result_shape()).c_str()); + if (condition_layout.parameter_layout(0) != + body_layout.parameter_layout(0)) { + VLOG(2) << "Reset %while condition parameter layout: cond=" + << condition->name() << " while=" << instruction->name() + << " shape=" << body_layout.parameter_layout(0).ToString(); + *condition_layout.mutable_parameter_layout(0) = + body_layout.parameter_layout(0); } // Constrain the output and the operand of the while instruction to match @@ -557,7 +556,20 @@ Status LayoutAssignment::AddMandatoryConstraints( true_computation_layout.parameter_shape(0))); DCHECK(ShapeUtil::Compatible( false_operand->shape(), false_computation_layout.parameter_shape(0))); - + if (true_computation_layout.result_layout() != + false_computation_layout.result_layout()) { + // We assign layouts in DFS fashion, so the true and false 
computations + // might have negotiated a different layout. But for the conditional + // instruction POV the layout must match, so we run again on the false + // computation, this time with proper computation layout. + VLOG(2) << "Reset %conditional false computation result layout: " + "false_computation=" + << false_computation->name() + << " conditional=" << instruction->name() << " shape=" + << true_computation_layout.result_layout().ToString(); + *false_computation_layout.mutable_result_layout() = + true_computation_layout.result_layout(); + } TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( true_computation_layout.result_shape(), instruction)); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( @@ -593,10 +605,14 @@ Status LayoutAssignment::AddMandatoryConstraints( } } } - - // Finally set the result layout to match ComputationLayout. - return constraints->SetResultLayout( - computation_layout.result_layout().shape()); + // Finally set the result layout to match ComputationLayout, if there is one. 
+ if (computation_layout != nullptr) { + const ShapeLayout& result_layout = computation_layout->result_layout(); + if (result_layout.LayoutIsSet()) { + TF_RETURN_IF_ERROR(constraints->SetResultLayout(result_layout.shape())); + } + } + return Status::OK(); } namespace { @@ -760,6 +776,7 @@ StatusOr LayoutAssignment::CreateCopyWithNewLayout( HloInstruction* copy = instruction->parent()->AddInstruction(HloInstruction::CreateUnary( instruction->shape(), HloOpcode::kCopy, instruction)); + RegisterAddedCopy(copy); SetupCopiedInstruction(*instruction, copy, {}); LayoutUtil::ClearLayout(copy->mutable_shape()); TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( @@ -783,13 +800,19 @@ Status LayoutAssignment::CopyOperandIfLayoutsDiffer( TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape())); if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) { + VLOG(5) << "Operand " << operand->ToString() << " layout matches in " + << instruction->ToString(); // Operand layout already matches our constraint. Nothing to do. return Status::OK(); } + VLOG(4) << "Operand " << operand->ToString() << " layout does not match " + << operand_layout.ToString() << " in " << instruction->ToString(); TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy, CreateCopyWithNewLayout(operand_layout.shape(), operand)); + VLOG(4) << "New copy of " << operand->ToString() << " is " + << operand_copy->ToString(); return instruction->ReplaceOperandWith(operand_no, operand_copy); } @@ -896,32 +919,31 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { } } } - - // Finally verify the result layout matches the layout of the entry + // Finally verify the result layout, if set, matches the layout of the entry // computation root. 
- TF_RET_CHECK(ShapeUtil::Equal( - module->entry_computation()->root_instruction()->shape(), + const ShapeLayout& result_layout = FindOrDie(computation_layouts_, module->entry_computation()) - .result_layout() - .shape())); - + .result_layout(); + if (result_layout.LayoutIsSet()) { + TF_RET_CHECK(ShapeUtil::Equal( + module->entry_computation()->root_instruction()->shape(), + result_layout.shape())); + } return Status::OK(); } LayoutAssignment::LayoutAssignment( - const ComputationLayout& entry_computation_layout, + ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints) : entry_computation_layout_(entry_computation_layout), channel_layout_constraints_(channel_constraints) { - VLOG(1) << "entry computation layout given to layout assignment: " - << entry_computation_layout_.ToString(); + VLOG(1) << "Entry computation layout given to layout assignment: " + << entry_computation_layout_->ToString(); // Layouts of all parameter instructions must be set. for (const ShapeLayout& parameter_layout : - entry_computation_layout_.parameter_layouts()) { + entry_computation_layout_->parameter_layouts()) { CHECK(parameter_layout.LayoutIsSet()); } - // TODO(b/29118294): Choose a better layout if the result layout is not set. 
- CHECK(entry_computation_layout_.result_layout().LayoutIsSet()); } std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( @@ -1481,16 +1503,60 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, return Status::OK(); } +Status LayoutAssignment::CalculateComputationLayout( + HloComputation* computation) { + ComputationLayout computation_layout(computation->ComputeProgramShape(), + /*ignore_layouts=*/false); + InsertOrDie(&computation_layouts_, computation, computation_layout); + VLOG(2) << " Calculated ComputationLayout = " + << computation_layout.ToString(); + return Status::OK(); +} + +Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { + // Clear existing layouts of the instructions. All layouts must be assigned + // by the LayoutAssignment pass, except for those on infeeds, parameters, + // and the computation result. The latter two are specified in + // computation_layout, so we only need to keep the existing layouts for + // infeeds. Clearing the layouts here avoids hiding potential bugs in the + // layout assignment pass that may accidently use the existing layout. + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kBitcast) { + // bitcasts are inherently layout sensitive and so a bitcast instruction + // present in the IR before layout assignment is a bug. 
+ return InternalError( + "Unexpected bitcast operation seen during layout assignment: %s.", + instruction->ToString().c_str()); + } + if (instruction->opcode() != HloOpcode::kInfeed) { + LayoutUtil::ClearLayout(instruction->mutable_shape()); + } + } + return Status::OK(); +} + Status LayoutAssignment::RunOnComputation( - const ComputationLayout& computation_layout, + ComputationLayout* computation_layout, const TuplePointsToAnalysis& points_to_analysis, HloComputation* computation, ChannelLayoutConstraints* channel_constraints) { - DCHECK(computation_layout.LayoutIsSet()); - InsertOrDie(&computation_layouts_, computation, computation_layout); VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name() << ")"; - VLOG(2) << " ComputationLayout = " << computation_layout.ToString(); + TF_RETURN_IF_ERROR(ClearComputationLayouts(computation)); + if (computation_layout != nullptr) { + auto it = computation_layouts_.find(computation); + if (it == computation_layouts_.end()) { + VLOG(2) << " New ComputationLayout = " << computation_layout->ToString(); + computation_layouts_.emplace(computation, *computation_layout); + } else { + TF_RET_CHECK(computation_layout == &it->second || + computation_layout == entry_computation_layout_); + VLOG(2) << " Existing ComputationLayout = " + << computation_layout->ToString(); + } + } else { + VLOG(2) << " No ComputationLayout specified (will be calculated)"; + } // Construct LayoutConstraints with all layout constraints of the computation. LayoutConstraints constraints(points_to_analysis, computation); @@ -1533,12 +1599,19 @@ Status LayoutAssignment::RunOnComputation( CHECK_LT(constraints.unconstrained_buffer_ids().size(), unconstrained_count); } - // All logical buffers should have constraints at this point. All that // remains is assign the constraints to the buffers and infer layouts for // aliased buffers. 
TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation)); + // If the computation layout wasn't specified, now it is the time to compute + // it according to the parameters and root instruction layouts. + // This allows the first pass through this API to record the best flowing + // layout to parameters and root instruction. + if (computation_layout == nullptr) { + TF_RETURN_IF_ERROR(CalculateComputationLayout(computation)); + } + // Record the layouts assigned for any communication ops in // channel_constraints so that they are constrained for future modules. for (HloInstruction* instruction : computation->instructions()) { @@ -1553,6 +1626,34 @@ Status LayoutAssignment::RunOnComputation( return Status::OK(); } +Status LayoutAssignment::PropagateComputationLayouts( + HloComputation* computation, ComputationLayout* computation_layout) { + ComputationLayout computed_computation_layout( + computation->ComputeProgramShape(), + /*ignore_layouts=*/false); + for (int64 i = 0; i < computed_computation_layout.parameter_count(); ++i) { + ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i); + if (!param_layout->LayoutIsSet()) { + VLOG(4) << "Assigning layout to parameter " << i << " of computation " + << computation->name() << ": " + << computed_computation_layout.parameter_layout(i).ToString(); + *param_layout = computed_computation_layout.parameter_layout(i); + } else { + TF_RET_CHECK(computed_computation_layout.parameter_layout(i) == + *param_layout); + } + } + ShapeLayout* result_layout = computation_layout->mutable_result_layout(); + if (!result_layout->LayoutIsSet()) { + VLOG(4) << "Assigning result layout of computation " << computation->name() + << ": " << computed_computation_layout.result_layout().ToString(); + *result_layout = computed_computation_layout.result_layout(); + } else { + TF_RET_CHECK(computed_computation_layout.result_layout() == *result_layout); + } + return Status::OK(); +} + StatusOr LayoutAssignment::Run(HloModule* 
module) { VLOG(2) << "Running layout assignment on module " << module->name(); XLA_VLOG_LINES(3, module->ToString()); @@ -1561,52 +1662,45 @@ StatusOr LayoutAssignment::Run(HloModule* module) { "before layout assignment", module->config().debug_options()); } - - TF_ASSIGN_OR_RETURN(auto points_to_analysis, - TuplePointsToAnalysis::Run(module)); - - // Assign layouts to computations in an order such that a callee computation - // is handled before its caller computation. This ensures that the layout of - // all callers of a computation will agree. - std::list computation_post_order = - module->MakeComputationPostOrder(); - for (auto* computation : module->MakeComputationPostOrder()) { - if (computation->IsFusionComputation()) { - continue; - } - // Clear existing layouts of the instructions. All layouts must be assigned - // by the LayoutAssignment pass, except for those on infeeds, parameters, - // and the computation result. The latter two are specified in - // computation_layout, so we only need to keep the existing layouts for - // infeeds. Clearing the layouts here avoids hiding potential bugs in the - // layout assignment pass that may accidently use the existing layout. - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kBitcast) { - // bitcasts are inherently layout sensitive and so a bitcast instruction - // present in the IR before layout assignment is a bug. - return InternalError( - "Unexpected bitcast operation seen during layout assignment: %s.", - instruction->ToString().c_str()); + TF_RETURN_IF_ERROR(Init()); + + // We do two passes. The first one we pass a nullptr ComputationLayout to + // the RunOnComputation() calls (for non entry computations), and we register + // the ComputationLayout which are naturally flowing in DFS fashion to the + // parameters and root instruction. 
+ // Walking in DFS mode though, means that we can end up with incorrect layouts + // when seen from an outer instruction, which has across-computation + // constraints to impose. + // For example, the kWhile instruction needs to enforce the same layouts for + // the parameters and root of the bosy, as well as the condition parameters. + // Similarly, the kConditional instruction needs to enforce the same layouts + // for the root of the true and false computations. + // So in the first pass, while allowing the layouts to flow to parameters and + // root, we also fix up the eventually inconsistent ComputationLayout, which + // will be then made mandatory by the second pass. + for (int64 i = 0; i < 2; ++i) { + TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module)); + TF_ASSIGN_OR_RETURN(auto points_to_analysis, + TuplePointsToAnalysis::Run(module)); + for (auto* computation : module->MakeComputationPostOrder()) { + if (computation->IsFusionComputation()) { + continue; } - if (instruction->opcode() != HloOpcode::kInfeed) { - LayoutUtil::ClearLayout(instruction->mutable_shape()); + if (computation == module->entry_computation()) { + TF_RETURN_IF_ERROR(RunOnComputation( + entry_computation_layout_, *points_to_analysis, + module->entry_computation(), channel_layout_constraints_)); + } else { + ComputationLayout* computation_layout = + (i == 0) ? nullptr : &FindOrDie(computation_layouts_, computation); + TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, + *points_to_analysis, computation, + channel_layout_constraints_)); } } - if (computation == module->entry_computation()) { - TF_RETURN_IF_ERROR(RunOnComputation( - entry_computation_layout_, *points_to_analysis, - module->entry_computation(), channel_layout_constraints_)); - } else { - ComputationLayout computation_layout(computation->ComputeProgramShape()); - // Setting all embedded computations to the default layout is potentially - // suboptimal. 
- computation_layout.SetToDefaultLayout(); - TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, - *points_to_analysis, computation, - channel_layout_constraints_)); - } } - + TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(), + entry_computation_layout_)); TF_RETURN_IF_ERROR(CheckLayouts(module)); VLOG(3) << "After layout assignment:"; @@ -1616,9 +1710,54 @@ StatusOr LayoutAssignment::Run(HloModule* module) { "after layout assignment", module->config().debug_options()); } - // All layouts are reset then reassigned by this pass. return true; } +Status LayoutAssignment::Init() { + computation_layouts_.clear(); + return Status::OK(); +} + +Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) { + // Clear all the copies which have been added, and all the related + // instructions (like GTE and tuples). + int64 removed_copies = 0; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kCopy && + added_copies_.count(instruction) > 0) { + VLOG(5) << "Removing added copy: " << instruction->ToString(); + TF_RETURN_IF_ERROR( + instruction->ReplaceAllUsesWith(instruction->mutable_operand(0))); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); + ++removed_copies; + } + } + } + added_copies_.clear(); + if (removed_copies > 0) { + TupleSimplifier tuple_simplifier; + HloDCE dce; + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + TF_RETURN_IF_ERROR(dce.Run(module).status()); + } + return Status::OK(); +} + +Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction, + int64 operand_number) { + HloInstruction* operand = instruction->mutable_operand(operand_number); + if (operand->opcode() != HloOpcode::kCopy || operand->user_count() > 1) { + HloInstruction* copy = + instruction->parent()->AddInstruction(HloInstruction::CreateUnary( + operand->shape(), 
HloOpcode::kCopy, operand)); + SetupCopiedInstruction(*operand, copy, {}); + LayoutUtil::ClearLayout(copy->mutable_shape()); + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy)); + } + return Status::OK(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 9663a793fdd7d4968700707a1003319e89ea19a3..c287cca0c54ba1bb514bd8d243c137eca99b258f 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -288,7 +289,7 @@ class LayoutAssignment : public HloPassInterface { // If channel_constraints is nullptr, no kSend or kRecvs must be contained // within any module passed to `Run`. explicit LayoutAssignment( - const ComputationLayout& entry_computation_layout, + ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints = nullptr); ~LayoutAssignment() override {} tensorflow::StringPiece name() const override { return "layout-assignment"; } @@ -362,12 +363,15 @@ class LayoutAssignment : public HloPassInterface { int64 operand_no); private: + // Initializes the layout assignment object for a new Run() call. + Status Init(); + // Adds constraints which must be satisfied for correctness on all // backends. Called once prior to propagating constraints. 
- Status AddMandatoryConstraints( - const ComputationLayout& computation_layout, - const ChannelLayoutConstraints* channel_constraints, - HloComputation* computation, LayoutConstraints* constraints); + Status AddMandatoryConstraints(const ComputationLayout* computation_layout, + ChannelLayoutConstraints* channel_constraints, + HloComputation* computation, + LayoutConstraints* constraints); // This method can be overridden to add backend-specific constraints to the // layout of the instructions of a computation. This method is called after @@ -378,10 +382,12 @@ class LayoutAssignment : public HloPassInterface { } // Construct contraints and assign layouts to all instructions in the - // computation satisfying the given ComputationLayout. Layouts constraints are - // added, then propagated until all LogicalBuffers in the computation are - // constrained. - Status RunOnComputation(const ComputationLayout& computation_layout, + // computation satisfying the given ComputationLayout, if not nullptr. + // Otherwise the ComputationLayout will be calculated by propagating the + // computation instruction contraints. + // Layouts constraints are added, then propagated until all LogicalBuffers in + // the computation are constrained. + Status RunOnComputation(ComputationLayout* computation_layout, const TuplePointsToAnalysis& points_to_analysis, HloComputation* computation, ChannelLayoutConstraints* channel_constraints); @@ -402,7 +408,26 @@ class LayoutAssignment : public HloPassInterface { // necessary conditions. Status CheckLayouts(HloModule* module); - const ComputationLayout& entry_computation_layout_; + // Computes the ComputationLayout of the given computation based of the + // layouts assigned to parameters and root instruction, and inserts it to the + // computation_layouts_ map. + Status CalculateComputationLayout(HloComputation* computation); + + // Clears all the layouts which can be cleared within a computation. 
+ Status ClearComputationLayouts(HloComputation* computation); + + // Clears the side effects of a previous pass, like added copy instructions. + Status ClearPreviousPassSideEffects(HloModule* module); + + // Propagates the layouts computed by the layout assignment pass on the given + // computation, to the computation layout passed in to this API. + // This API propagates missing layout, and also checks that the caller + // specified have been respected, by comparing those with the parameters and + // root computation instruction. + Status PropagateComputationLayouts(HloComputation* computation, + ComputationLayout* computation_layout); + + ComputationLayout* entry_computation_layout_; protected: // Sets up the copy instruction according to the characteristic (sharding, @@ -418,21 +443,37 @@ class LayoutAssignment : public HloPassInterface { // Creates and returns a copy of the given instruction with a different // layout. Tuple-shaped instructions will be deep-copied, and the last Tuple // instruction producing the copy is returned. - static StatusOr CreateCopyWithNewLayout( + StatusOr CreateCopyWithNewLayout( const Shape& shape_with_layout, HloInstruction* instruction); // Creates a copy of the given operand if the operand's layout does not match // the given layout. This copy replaces the use in the given instruction. // Tuple operands will be deep-copied. - static Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, - HloInstruction* instruction, - int64 operand_no); + Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, + HloInstruction* instruction, + int64 operand_no); + + // Registers a copy instruction added by the layout assignment pass. 
+ void RegisterAddedCopy(HloInstruction* copy) { + CHECK_EQ(copy->opcode(), HloOpcode::kCopy); + added_copies_.insert(copy); + } + + // Adds a copy for the operand of an instruction, unless such operand is + // already a copy, and has a single user (which is forcibly the instruction + // itself). + Status AddCopyForOperand(HloInstruction* instruction, int64 operand_number); // Map containing the layouts of all computations assigned so // far. Computations are handled in a topological sort where computations are // handled before their caller instructions so the layouts of caller // instructions can be set to match the computation. std::map computation_layouts_; + + // Every copy added to the module by the layout assignment pass is registered + // here. + tensorflow::gtl::FlatSet added_copies_; + ChannelLayoutConstraints* channel_layout_constraints_; }; diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 7e1bb11eaada0e62b82c50903c9848f0a3a8307b..bf0448a67674f24591d866b646b98aea09ebb12c 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -29,13 +29,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -53,7 +53,7 @@ class LayoutAssignmentTest : public HloTestBase { protected: void AssignLayouts(HloModule* module, ComputationLayout* entry_computation_layout) { - LayoutAssignment layout_assignment(*entry_computation_layout); + LayoutAssignment layout_assignment(entry_computation_layout); EXPECT_IS_OK(layout_assignment.Run(module).status()); } }; @@ -285,7 +285,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape( result_shape)); - LayoutAssignment layout_assignment(computation_layout); + LayoutAssignment layout_assignment(&computation_layout); AssignLayouts(module.get(), &computation_layout); // Layout assignment should have deep copied the result of the computation to @@ -488,7 +488,7 @@ class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment { public: explicit OperandsMustBeTheSameLayoutAssignment( ComputationLayout* entry_computation_layout) - : LayoutAssignment(*entry_computation_layout) {} + : LayoutAssignment(entry_computation_layout) {} protected: Status PropagateBufferConstraint( @@ -651,7 +651,7 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { } )"; - auto module = tools::Parse(module_str).ValueOrDie(); + auto module = 
ParseHloString(module_str).ValueOrDie(); module = backend() @@ -660,13 +660,12 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { /*device_allocator=*/nullptr) .ConsumeValueOrDie(); - EXPECT_EQ( - ::tensorflow::Status::OK(), - backend() - .compiler() - ->RunBackend(std::move(module), backend().default_stream_executor(), - /*device_allocator=*/nullptr) - .status()); + EXPECT_EQ(Status::OK(), backend() + .compiler() + ->RunBackend(std::move(module), + backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .status()); } // A GTE inside of a fusion node inherits the layout of its operand (which @@ -692,7 +691,7 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { } )"; - auto module = tools::Parse(module_str).ValueOrDie(); + auto module = ParseHloString(module_str).ValueOrDie(); ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( @@ -808,7 +807,7 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); - LayoutAssignment layout_assignment(computation_layout); + LayoutAssignment layout_assignment(&computation_layout); Status error_status = layout_assignment.Run(module.get()).status(); EXPECT_FALSE(error_status.ok()); EXPECT_THAT( diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc deleted file mode 100644 index 79dfd1e409f1556a50e9ba6c845cbf9774fb1a02..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/liveness_util.cc +++ /dev/null @@ -1,371 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/liveness_util.h" - -#include -#include -#include - -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" - -namespace xla { - -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user, - const TuplePointsToAnalysis& points_to_analysis) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) { - // GetTupleElement instructions only access the top-level buffer of their - // operand. - return true; - } else if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop) { - // Find fusion parameter associated with 'operand'. - auto it = std::find_if( - user->fused_parameters().begin(), user->fused_parameters().end(), - [=](HloInstruction* fused_param) { - return user->operand(fused_param->parameter_number()) == operand; - }); - CHECK(it != user->fused_parameters().end()); - // Iterate through all users of all buffer aliases of the buffer in the - // points-to set of fusion parameter at 'index'. 
- // Return false if any uses are detected at 'index', returns true otherwise. - const LogicalBuffer* buffer = - points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie(); - for (const BufferAlias& alias : - points_to_analysis.GetBufferAliases(*buffer)) { - for (HloInstruction* alias_user : alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), - alias_user, points_to_analysis)) { - continue; - } - // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'. - return false; - } - } - // Return true: found no uses of 'operand' at 'index' in 'user'. - return true; - } - return false; -} - -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user, - const HloDataflowAnalysis& dataflow) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop) { - // Find fusion parameter associated with 'operand'. - HloInstruction* fusion_param = - user->fused_parameter(user->operand_index(operand)); - // Iterate through all users of all uses of the fusion parameter value. - // Return false if any uses are detected, returns true otherwise. - const HloValue& value = dataflow.GetValueDefinedAt(fusion_param, index); - return value.uses().empty(); - } else { - // Return false if no value at 'operand' and 'index' is used at 'user'. - for (const HloValue* value : - dataflow.GetValueSet(operand, index).values()) { - for (const HloUse& use : value->uses()) { - if (use.instruction == user) { - return false; - } - } - } - } - - return true; -} - -namespace { - -// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'. 
-// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index) -// where 'user' is a user of an alias of 'instruction' at 'index', and -// 'operand_index' is the operand index at which the alias appears in the -// operand list of 'user'. -std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex( - HloInstruction* instruction, const ShapeIndex& index, - const TuplePointsToAnalysis& points_to_analysis) { - std::vector<std::pair<HloInstruction*, int64>> uses; - const PointsToSet::BufferList& points_to = - points_to_analysis.GetPointsToSet(instruction).element(index); - for (const LogicalBuffer* buffer : points_to) { - for (const BufferAlias& alias : - points_to_analysis.GetBufferAliases(*buffer)) { - for (HloInstruction* alias_user : alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), - alias_user, points_to_analysis)) { - continue; - } - for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) { - uses.emplace_back(alias_user, op_idx); - } - } - } - } - return uses; -} - -// Returns true if there is exactly one use of 'operand' at 'operand_index' -// in 'fusion.fused_instructions', where the singleton use is the fused -// root at operand index 'use_operand_index'. Returns false otherwise. -// -// REQUIRES: 'fusion' opcode is a kFusion instruction. -bool HasUniqueFusedUseOfOperandAt( - HloInstruction* operand, const ShapeIndex& operand_index, - HloInstruction* fusion, const int64 use_operand_index, - const TuplePointsToAnalysis& points_to_analysis) { - CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); - // Check that 'operand' is unique in the operand list of 'fusion'. - if (fusion->OperandIndices(operand).size() > 1) { - return false; - } - // Find fusion parameter associated with 'operand'.
- const auto& fused_params = fusion->fused_parameters(); - auto fused_param_it = std::find_if( - fused_params.begin(), fused_params.end(), - [&](HloInstruction* fused_param) { - return fusion->operand(fused_param->parameter_number()) == operand; - }); - if (fused_param_it == fused_params.end()) { - return false; - } - auto* fused_param = *fused_param_it; - // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'. - auto fused_param_uses = GetAllUsesOfInstructionAtIndex( - fused_param, operand_index, points_to_analysis); - // Return true iff there is exactly one use of 'operand' at 'index', and - // this singleton use is the fused root (at index in 'use_operand_indices'). - return fused_param_uses.size() == 1 && - fused_param_uses[0].first == fusion->fused_expression_root() && - fused_param_uses[0].second == use_operand_index; -} - -} // namespace - -// User and operand can share buffers iff both instructions emit the same shape -// and layout, and 'user' meets one of the following qualifications: -// -// (1) Is element-wise. Or... -// (2) Is a loop fusion instruction where the only use of 'operand' at 'index' -// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root -// at operand 0. Or... -// (3) Is a kDot -> kAdd output fusion instruction where the only use of -// 'operand' at 'index' in the set 'user.fused_instructions' is a kAdd fused -// root at operand 0 or 1. Or... -// (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index -// 0. -// -// (2) and (3) can only be determined if points-to analysis is available. 
-bool CanShareOperandBufferWithUser( - HloInstruction* operand, const ShapeIndex& operand_index, - HloInstruction* user, const ShapeIndex& user_index, - const TuplePointsToAnalysis& points_to_analysis) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - const Shape& operand_subshape = - ShapeUtil::GetSubshape(operand->shape(), operand_index); - const Shape& user_subshape = - ShapeUtil::GetSubshape(user->shape(), user_index); - // Check that operand and user emit the same shape and layout. - if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { - return false; - } - if (user->opcode() == HloOpcode::kFusion) { - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - // Loop fusion with kDynamicUpdateSlice fused root. - // - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root at operand - // index 0. - return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0, - points_to_analysis); - } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && - user->fused_expression_root()->opcode() == HloOpcode::kAdd) { - // Output fusion with kAdd fused root. - - // Check if one operand of kAdd fused root is kDot or kConvolution. - auto* add = user->fused_expression_root(); - auto add_operand_it = - std::find_if(add->operands().begin(), add->operands().end(), - [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kConvolution || - operand->opcode() == HloOpcode::kDot; - }); - if (add_operand_it == add->operands().end()) { - return false; - } - auto* matched_add_operand = *add_operand_it; - // Calculate operand index of 'add' operand which was not matched above. - const int64 other_add_operand_index = - matched_add_operand == add->operand(0) ? 
1 : 0; - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root (at operand - // index 'other_add_operand_index'). - return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, - other_add_operand_index, - points_to_analysis); - } - } - if (user->opcode() == HloOpcode::kDynamicUpdateSlice || - user->opcode() == HloOpcode::kWhile) { - // We eliminated other users in BufferLiveness::live_range_strictly_before, - // so here we just need to check that the use is at operand index 0. - std::vector<int64> operand_indices = user->OperandIndices(operand); - return operand_indices.size() == 1 && operand_indices[0] == 0; - } - if (user->opcode() == HloOpcode::kCall) { - // TODO(b/62548313): Remove when buffer assignment is module scoped and - // does not assign buffers to calls. - // Find called computation parameter associated with 'operand'. - const std::vector<int64> operand_indices = user->OperandIndices(operand); - if (operand_indices.size() > 1) { - return false; - } - CHECK_EQ(1, operand_indices.size()); - auto* param = user->to_apply()->parameter_instruction(operand_indices[0]); - // Get all uses of 'operand' at 'index' in called computation. - auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index, - points_to_analysis); - - // Return true iff: - // *) There exists exactly one use of 'operand' in called computation. - // *) The unique use is by the root instruction of called computation. - // (Note: we check the root of the called computation, because the - // root result buffer is required to alias with the Call result buffer). - // *) The root instruction of the called computation is element-wise on - // 'operand'. - auto* callee_root = user->to_apply()->root_instruction(); - return param_uses.size() == 1 && param_uses[0].first == callee_root && - callee_root->IsElementwiseOnOperand(param_uses[0].second); - } - // Check if 'user' is element-wise.
- return user->IsElementwise(); -} - -bool CanShareOperandBufferWithUser(HloInstruction* operand, - const ShapeIndex& operand_index, - HloInstruction* user, - const ShapeIndex& user_index, - const HloDataflowAnalysis& dataflow) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - const Shape& operand_subshape = - ShapeUtil::GetSubshape(operand->shape(), operand_index); - const Shape& user_subshape = - ShapeUtil::GetSubshape(user->shape(), user_index); - // Check that operand and user emit the same shape and layout. - if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { - return false; - } - - if (user->opcode() == HloOpcode::kFusion) { - // Get the parameter associated with 'operand'; - HloInstruction* fusion_param = - user->fused_parameter(user->operand_index(operand)); - - const HloValue& value = - dataflow.GetValueDefinedAt(fusion_param, operand_index); - if (value.uses().size() != 1) { - return false; - } - const HloUse& use = value.uses()[0]; - - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - // Loop fusion with kDynamicUpdateSlice fused root. - // - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root at operand - // index 0. - return use.instruction == user->fused_expression_root() && - use.operand_number == 0; - } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && - user->fused_expression_root()->opcode() == HloOpcode::kAdd) { - // Output fusion with kAdd fused root. - - // Check if one operand of kAdd fused root is kDot, or kConvolution. 
- auto* add = user->fused_expression_root(); - auto add_operand_it = - std::find_if(add->operands().begin(), add->operands().end(), - [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kConvolution || - operand->opcode() == HloOpcode::kDot; - }); - if (add_operand_it == add->operands().end()) { - return false; - } - auto* matched_add_operand = *add_operand_it; - // Calculate operand index of 'add' operand which was not matched above. - const int64 other_add_operand_index = - matched_add_operand == add->operand(0) ? 1 : 0; - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root (at operand - // index 'other_add_operand_index'). - return use.instruction == user->fused_expression_root() && - use.operand_number == other_add_operand_index; - } - } - if (user->opcode() == HloOpcode::kDynamicUpdateSlice || - user->opcode() == HloOpcode::kWhile) { - // We eliminated other users in BufferLiveness::live_range_strictly_before, - // so here we just need to check that the use is at operand index 0. - std::vector<int64> operand_indices = user->OperandIndices(operand); - return operand_indices.size() == 1 && operand_indices[0] == 0; - } - if (user->opcode() == HloOpcode::kCall) { - // Get all uses of value defined by 'operand' at 'operand_index'. - const auto& uses = - dataflow.GetValueDefinedAt(operand, operand_index).uses(); - // Return true iff: - // *) There exists two uses of 'operand'. - // *) One use is by 'user' (caller). - // *) One use is by root instruction of called computation (callee root). - // (Note: we check the root of the called computation, because the - // root result buffer is required to alias with the Call result buffer). - // *) The root instruction of the called computation is element-wise on - // 'operand'.
- const bool found_caller_use = - std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) { - return use.instruction == user; - }) != uses.end(); - auto* callee_root = user->to_apply()->root_instruction(); - const bool found_elementwise_callee_use = - std::find_if( - uses.begin(), uses.end(), [callee_root](const HloUse& use) { - return use.instruction == callee_root && - callee_root->IsElementwiseOnOperand(use.operand_number); - }) != uses.end(); - return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; - } - // Check if 'user' is element-wise. - return user->IsElementwise(); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/liveness_util.h b/tensorflow/compiler/xla/service/liveness_util.h deleted file mode 100644 index 28ef991880039de73cc158a67ef2a5f78fc90e6d..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/liveness_util.h +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// A collection of utilities on the HLO graph. 
- -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_ - -#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/types.h" - -namespace xla { - -// Returns true if 'user' cannot possibly use the buffer at 'index' in -// 'operand'. Returns false otherwise. -// -// REQUIRES: 'operand' is an operand of 'user'. -// -// TODO(b/65835246): Remove TuplePointsToAnalysis overload when all users have -// moved over to the dataflow overload. -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user, - const TuplePointsToAnalysis& points_to_analysis); -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user, - const HloDataflowAnalysis& dataflow); - -// Returns true if 'user' (at 'user_index') can share a buffer with its operand -// 'operand' (at 'operand_index'). Returns false otherwise. -// -// REQUIRES: 'operand' is an operand of 'user'. -// -// TODO(b/65835246): Remove TuplePointsToAnalysis overload when all users have -// moved over to the dataflow overload. 
-bool CanShareOperandBufferWithUser( - HloInstruction* operand, const ShapeIndex& operand_index, - HloInstruction* user, const ShapeIndex& user_index, - const TuplePointsToAnalysis& points_to_analysis); -bool CanShareOperandBufferWithUser(HloInstruction* operand, - const ShapeIndex& operand_index, - HloInstruction* user, - const ShapeIndex& user_index, - const HloDataflowAnalysis& dataflow); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc deleted file mode 100644 index c01b52df62ee67eb2c6249bfa0baf8366dd3c331..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ /dev/null @@ -1,421 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/liveness_util.h" - -#include - -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" - -namespace xla { -namespace { - -class PointsToAnalysisTestBase : public HloTestBase { - protected: - void BuildModule(std::unique_ptr<HloComputation> computation) { - module_ = CreateNewModule(); - computation_ = module_->AddEntryComputation(std::move(computation)); - } - - void RunAnalysis() { - CHECK_NOTNULL(module_.get()); - points_to_analysis_ = - TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); - dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie(); - } - - void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) { - BuildModule(std::move(computation)); - RunAnalysis(); - } - - std::unique_ptr<HloModule> module_; - HloComputation* computation_ = nullptr; - std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_; - std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_; -}; - -class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {}; - -TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) { - auto builder = HloComputation::Builder(TestName()); - - Shape elem_shape = ShapeUtil::MakeShape(F32, {8}); - auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1)); - builder.AddInstruction( - HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1)); - - BuildModuleAndRunAnalysis(builder.Build()); - - // GetTupleElement instructions only access the top-level buffer of their - // operand.
- EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *points_to_analysis_)); - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *points_to_analysis_)); - EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *points_to_analysis_)); - EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *points_to_analysis_)); - - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *dataflow_analysis_)); - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *dataflow_analysis_)); - EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *dataflow_analysis_)); - EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *dataflow_analysis_)); -} - -TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { - auto builder = HloComputation::Builder(TestName()); - - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); - - // Create a DynamicUpdateSlice instruction of tuple element 1. - auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); - auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); - auto dynamic_update_slice = - builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); - builder.AddInstruction( - HloInstruction::CreateTuple({gte0, dynamic_update_slice})); - - BuildModule(builder.Build()); - auto fusion = computation_->CreateFusionInstruction( - {dynamic_update_slice, starts, update, gte1}, - HloInstruction::FusionKind::kLoop); - RunAnalysis(); - - // The fusion instruction never uses tuple element 0, but does use element 1. 
- EXPECT_TRUE( - DoesNotUseOperandBuffer(tuple, {0}, fusion, *points_to_analysis_)); - EXPECT_FALSE( - DoesNotUseOperandBuffer(tuple, {1}, fusion, *points_to_analysis_)); - - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, fusion, *dataflow_analysis_)); - EXPECT_FALSE( - DoesNotUseOperandBuffer(tuple, {1}, fusion, *dataflow_analysis_)); -} - -class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {}; - -TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { - auto builder = HloComputation::Builder(TestName()); - - Shape shape = ShapeUtil::MakeShape(F32, {8}); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param")); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); - auto log = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp)); - - BuildModuleAndRunAnalysis(builder.Build()); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, log, {}, *points_to_analysis_)); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, log, {}, *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { - auto builder = HloComputation::Builder(TestName()); - - Shape in_shape = ShapeUtil::MakeShape(F32, {8}); - Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); - auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, in_shape, "param0")); - auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(1, in_shape, "param1")); - auto result = builder.AddInstruction( - HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); - - BuildModuleAndRunAnalysis(builder.Build()); - - EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {}, - *points_to_analysis_)); - 
EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {}, - *points_to_analysis_)); - - EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {}, - *dataflow_analysis_)); - EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { - auto builder = HloComputation::Builder(TestName()); - - Shape shape = ShapeUtil::MakeShape(F32, {8}); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param")); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); - auto copy = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp)); - - BuildModuleAndRunAnalysis(builder.Build()); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, copy, {}, *points_to_analysis_)); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, copy, {}, *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { - auto builder = HloComputation::Builder(TestName()); - - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); - - // Create a DynamicUpdateSlice instruction of tuple element 1. 
- auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); - auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); - auto dynamic_update_slice = - builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); - builder.AddInstruction( - HloInstruction::CreateTuple({gte0, dynamic_update_slice})); - - BuildModule(builder.Build()); - auto fusion = computation_->CreateFusionInstruction( - {dynamic_update_slice, starts, update, gte1}, - HloInstruction::FusionKind::kLoop); - RunAnalysis(); - - // The fusion instruction can share with tuple element 1. - EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {}, - *points_to_analysis_)); - EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {}, - *points_to_analysis_)); - - EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {}, - *dataflow_analysis_)); - EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { - auto builder = HloComputation::Builder(TestName()); - - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - Shape update_shape = ShapeUtil::MakeShape(F32, {4}); - Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); - auto data = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape, "data")); - auto update = builder.AddInstruction( - HloInstruction::CreateParameter(1, update_shape, "update")); - auto starts = builder.AddInstruction( - HloInstruction::CreateParameter(2, starts_shape, "starts")); - auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, data, update, starts)); - - BuildModuleAndRunAnalysis(builder.Build()); - - // The DynamicUpdateSlice instruction can share with the data operand, but not - // with update or starts. 
- EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, dus, {}, *points_to_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(update, {}, dus, {}, *points_to_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(starts, {}, dus, {}, *points_to_analysis_)); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, dus, {}, *dataflow_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(update, {}, dus, {}, *dataflow_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(starts, {}, dus, {}, *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { - auto builder = HloComputation::Builder(TestName()); - Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); - - auto a = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); - auto b = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); - - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); - - auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto add_operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); - - auto add = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape, HloOpcode::kAdd, dot, add_operand)); - - BuildModule(builder.Build()); - auto fusion = computation_->CreateFusionInstruction( - {add, dot}, HloInstruction::FusionKind::kOutput); - RunAnalysis(); - - // Output fused dot add should be able to share buffer with 'add_operand'. 
- EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *points_to_analysis_)); - - EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { - auto builder = HloComputation::Builder(TestName()); - Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); - - auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); - - auto reverse = builder.AddInstruction( - HloInstruction::CreateReverse(data_shape, operand, {0, 1})); - - auto two = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); - - auto add = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); - - BuildModule(builder.Build()); - auto fusion = computation_->CreateFusionInstruction( - {add, two, reverse}, HloInstruction::FusionKind::kOutput); - RunAnalysis(); - - // Output fused operand->reverse->add cannot alias operand buffer 'operand'. 
- EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {}, - *points_to_analysis_)); - - EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - - auto make_cond = [this, &data_shape]() { - auto builder = HloComputation::Builder(TestName() + ".Cond"); - auto data = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape, "data")); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); - return builder.Build(); - }; - - auto make_body = [this, &data_shape]() { - auto builder = HloComputation::Builder(TestName() + ".Body"); - auto data = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape, "data")); - builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data)); - return builder.Build(); - }; - - module_ = CreateNewModule(); - HloComputation* cond_computation = - module_->AddEmbeddedComputation(make_cond()); - HloComputation* body_computation = - module_->AddEmbeddedComputation(make_body()); - - auto builder = HloComputation::Builder(TestName()); - auto data = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape, "data")); - auto whil = builder.AddInstruction(HloInstruction::CreateWhile( - data_shape, cond_computation, body_computation, data)); - computation_ = module_->AddEntryComputation(builder.Build()); - - RunAnalysis(); - - // The While instruction can share with the data operand. - EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, whil, {}, *points_to_analysis_)); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, whil, {}, *dataflow_analysis_)); -} - -// Tests that Call can alias operand buffer if the only use of the operand -// in the called computation is an elementwise instruction. 
-TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { - Shape shape = ShapeUtil::MakeShape(F32, {8}); - // Build sub-computation with fusion root. - auto sub_builder = HloComputation::Builder(TestName() + "_sub"); - auto sub_param = sub_builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "sub_param")); - auto one = sub_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto ones = sub_builder.AddInstruction( - HloInstruction::CreateBroadcast(shape, one, {1})); - auto add = sub_builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); - - module_ = CreateNewModule(); - auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); - sub_computation->CreateFusionInstruction({add, ones}, - HloInstruction::FusionKind::kLoop); - - // Build entry-computation with kCall which calls 'sub_computation'. - auto builder = HloComputation::Builder(TestName()); - - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param")); - auto reverse = - builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0})); - auto call = builder.AddInstruction( - HloInstruction::CreateCall(shape, {reverse}, sub_computation)); - computation_ = module_->AddEntryComputation(builder.Build()); - - RunAnalysis(); - - EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {}, - *points_to_analysis_)); - EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {}, - *dataflow_analysis_)); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index bc683a1880b010d57e83aa6e9ffa95fda299e1a0..d909845a3a21fc55e44b0037371fca30e577980f 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -80,8 +80,10 @@ Status 
FusedIrEmitter::HandleConstant(HloInstruction* constant) { *ir_builder_->GetInsertBlock()->getModule(), initializer->getType(), /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer, /*Name=*/""); + llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( + global, llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); generators_[constant] = [=](const IrArray::Index& index) { - return IrArray(global, constant->shape()) + return IrArray(shape_constant, constant->shape()) .EmitReadArrayElement(index, ir_builder_); }; @@ -151,7 +153,7 @@ Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) { Status FusedIrEmitter::FinishVisit(HloInstruction* root) { fused_root_ = root; - return tensorflow::Status::OK(); + return Status::OK(); } FusedIrEmitter::Generator FusedIrEmitter::GetRootGenerator() const { diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index 23d2d4e87d26f4988ebddcf20f5a27af6a7fe0d6..1f6e3c829f890d68aa251b101f0402c120a19d61 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -15,53 +15,57 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" -#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" namespace xla { -void KernelSupportLibrary::For( +Status KernelSupportLibrary::For( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, - const std::function& for_body_generator) { - If(ir_builder_->CreateICmpSLT(start, end), [&]() { - for_body_generator(start, /*is_first_iteration=*/true); - For(name, ir_builder_->CreateAdd(start, step), end, step, - [&](llvm::Value* iv) { for_body_generator(iv, false); }); + const std::function& for_body_generator) { + return If(ir_builder_->CreateICmpSLT(start, end), [&]() -> Status { + TF_RETURN_IF_ERROR(for_body_generator(start, /*is_first_iteration=*/true)); + return For(name, ir_builder_->CreateAdd(start, step), end, step, + [&](llvm::Value* iv) { return for_body_generator(iv, false); }); }); } -void KernelSupportLibrary::For( +Status KernelSupportLibrary::For( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, - const std::function& for_body_generator) { + const std::function& + for_body_generator) { if (peel_first_iteration) { - For(name, start, end, step, true, - [&](llvm::Value* indvar, bool is_first_iteration) { - for_body_generator(indvar, ir_builder_->getInt1(is_first_iteration)); - }); + return For(name, start, end, step, true, + [&](llvm::Value* indvar, bool is_first_iteration) -> Status { + return for_body_generator( + indvar, ir_builder_->getInt1(is_first_iteration)); + }); } else { std::unique_ptr loop = llvm_ir::ForLoop::EmitForLoop( name, start, end, step, ir_builder_, - /*prevent_unrolling=*/prevent_unrolling_, + /*unroll_mode=*/unroll_mode_, /*prevent_vectorization=*/prevent_vectorization_); ir_builder_->SetInsertPoint(&loop->GetBodyBasicBlock()->back()); - for_body_generator(loop->GetIndVarValue(), - 
/*is_first_iteration=*/ir_builder_->CreateICmpEQ( - loop->GetIndVarValue(), start)); + TF_RETURN_IF_ERROR( + for_body_generator(loop->GetIndVarValue(), + /*is_first_iteration=*/ir_builder_->CreateICmpEQ( + loop->GetIndVarValue(), start))); llvm_ir::SetToLastInsertPoint(loop->GetExitBasicBlock(), ir_builder_); + return Status::OK(); } } -void KernelSupportLibrary::If( - llvm::Value* condition, const std::function& true_block_generator, - const std::function& false_block_generator) { +Status KernelSupportLibrary::If( + llvm::Value* condition, const std::function& true_block_generator, + const std::function& false_block_generator) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(condition, "", ir_builder_); ir_builder_->SetInsertPoint(&if_data.true_block->back()); - true_block_generator(); + TF_RETURN_IF_ERROR(true_block_generator()); ir_builder_->SetInsertPoint(&if_data.false_block->back()); - false_block_generator(); + TF_RETURN_IF_ERROR(false_block_generator()); llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_); + return Status::OK(); } void KernelSupportLibrary::EmitAndCallOutlinedKernel( diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 1c00b2aabd182da72e78d2c9c01cbe70cfd8e33c..e17c649e5272a9e7c0d5126083ab76542abfdf48 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -30,13 +31,14 @@ namespace xla { class KernelSupportLibrary { public: // `ir_builder` is the llvm::IRBuilder instance used to generate LLVM IR. 
- // If `prevent_unrolling` is true then unrolling is explicitly disabled on - // every loop generated by this instance of KernelSupportLibrary. - explicit KernelSupportLibrary(llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling = true, - bool prevent_vectorization = true) + // `unroll_mode` specifies the desired LLVM unrolling behavior for every loop + // generated by this instance of KernelSupportLibrary. + explicit KernelSupportLibrary( + llvm::IRBuilder<>* ir_builder, + llvm_ir::UnrollMode unroll_mode = llvm_ir::UnrollMode::kNoUnroll, + bool prevent_vectorization = true) : ir_builder_(ir_builder), - prevent_unrolling_(prevent_unrolling), + unroll_mode_(unroll_mode), prevent_vectorization_(prevent_vectorization) {} // Generates the following control flow structure: @@ -46,19 +48,41 @@ class KernelSupportLibrary { // for (i64 i = `start` + `step`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; // } - void For( + Status For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& for_body_generator); + + void ForReturnVoid( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& - for_body_generator); + for_body_generator) { + CHECK_EQ(Status::OK(), + For(name, start, end, step, + [&](llvm::Value* ind_var, bool is_first_iteration) -> Status { + for_body_generator(ind_var, is_first_iteration); + return Status::OK(); + })); + } + + Status For(tensorflow::StringPiece name, int64 start, int64 end, int64 step, + const std::function& + for_body_generator) { + return For(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); + } - void For( + void ForReturnVoid( tensorflow::StringPiece name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - For(name, /*start=*/ir_builder_->getInt64(start), - 
/*end=*/ir_builder_->getInt64(end), - /*step=*/ir_builder_->getInt64(step), for_body_generator); + ForReturnVoid(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); } // Generates the following control flow structure if `peel_first_iteration` is @@ -75,37 +99,101 @@ class KernelSupportLibrary { // for (i64 i = `start`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, // /*is_first_iteration=*/,(i != `start`))`; - void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - llvm::Value* step, bool peel_first_iteration, - const std::function& + Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& + for_body_generator); + + void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + llvm::Value* end, llvm::Value* step, + bool peel_first_iteration, + const std::function& + for_body_generator) { + TF_CHECK_OK(For( + name, start, end, step, peel_first_iteration, + [&](llvm::Value* ind_var, llvm::Value* is_first_iteration) -> Status { + for_body_generator(ind_var, is_first_iteration); + return Status::OK(); + })); + } + + Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + int64 step, bool peel_first_iteration, + const std::function& + for_body_generator) { + return For(name, /*start=*/start, /*end=*/end, + /*step=*/ir_builder_->getInt64(step), peel_first_iteration, for_body_generator); + } - void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - int64 step, bool peel_first_iteration, - const std::function& - for_body_generator) { - For(name, /*start=*/start, /*end=*/end, - /*step=*/ir_builder_->getInt64(step), peel_first_iteration, - for_body_generator); + void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + llvm::Value* end, int64 step, bool peel_first_iteration, + const 
std::function& + for_body_generator) { + ForReturnVoid(name, /*start=*/start, /*end=*/end, + /*step=*/ir_builder_->getInt64(step), peel_first_iteration, + for_body_generator); } - void For( + Status For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& for_body_generator) { + return For(name, start, end, step, + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); + } + + void ForReturnVoid( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - For(name, start, end, step, - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); }); + ForReturnVoid(name, start, end, step, + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) { + return for_body_generator(indvar); + }); + } + + Status For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + int64 step, + const std::function& for_body_generator) { + return For(name, start, end, ir_builder_->getInt64(step), + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); } - void For( + void ForReturnVoid( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + int64 step, + const std::function& for_body_generator) { + ForReturnVoid(name, start, end, ir_builder_->getInt64(step), + for_body_generator); + } + + Status For( + tensorflow::StringPiece name, int64 start, int64 end, int64 step, + const std::function& for_body_generator) { + return For(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); + } + + void ForReturnVoid( tensorflow::StringPiece name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - For(name, 
/*start=*/ir_builder_->getInt64(start), - /*end=*/ir_builder_->getInt64(end), - /*step=*/ir_builder_->getInt64(step), for_body_generator); + ForReturnVoid(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); } // Generates the following control flow structure: @@ -114,9 +202,25 @@ class KernelSupportLibrary { // `true_block_generator()`; // else // `false_block_generator()`; - void If(llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = []() {}); + Status If(llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = + []() -> Status { return Status::OK(); }); + + void IfReturnVoid(llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() { + }) { + TF_CHECK_OK(If(condition, + [&]() { + true_block_generator(); + return Status::OK(); + }, + [&]() { + false_block_generator(); + return Status::OK(); + })); + } using ArgumentVector = tensorflow::gtl::ArraySlice; @@ -174,7 +278,7 @@ class KernelSupportLibrary { private: llvm::IRBuilder<>* ir_builder_; - bool prevent_unrolling_; + llvm_ir::UnrollMode unroll_mode_; bool prevent_vectorization_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 497b48ff227d7d1f158080529372df44b6932b24..9f867014fb015845448c4fcf9c165750f8a61935 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -34,7 +34,7 @@ namespace llvm_ir { ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - llvm::Value* step, bool prevent_unrolling, + llvm::Value* step, UnrollMode unroll_mode, bool prevent_vectorization) : 
prefix_(std::string(prefix)), suffix_(std::string(suffix)), @@ -42,15 +42,15 @@ ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, end_index_(end_index), step_(step), insert_before_bb_(nullptr), - prevent_unrolling_(prevent_unrolling), + unroll_mode_(unroll_mode), prevent_vectorization_(prevent_vectorization) {} /* static */ std::unique_ptr ForLoop::EmitForLoop( tensorflow::StringPiece prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling, bool prevent_vectorization) { + UnrollMode unroll_mode, bool prevent_vectorization) { std::unique_ptr loop(new ForLoop(prefix, /*suffix=*/"", start_index, - end_index, step, prevent_unrolling, + end_index, step, unroll_mode, prevent_vectorization)); loop->Emit(ir_builder); return loop; @@ -147,11 +147,12 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { std::vector ForLoop::GetLoopMetadata( llvm::IRBuilder<>* ir_builder) { const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable"; + const char* const kLlvmLoopUnrollFullMDName = "llvm.loop.unroll.full"; const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable"; llvm::LLVMContext* ctx = &start_index_->getContext(); std::vector result; - if (prevent_unrolling_) { + if (unroll_mode_ == xla::llvm_ir::UnrollMode::kNoUnroll) { result.push_back(llvm::MDNode::get( *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)})); } @@ -162,6 +163,10 @@ std::vector ForLoop::GetLoopMetadata( llvm::ConstantAsMetadata::get(ir_builder->getFalse())})); } + if (unroll_mode_ == xla::llvm_ir::UnrollMode::kFullyUnroll) { + result.push_back(llvm::MDNode::get( + *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollFullMDName)})); + } return result; } @@ -178,25 +183,25 @@ llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name, std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - 
bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1), - prevent_unrolling, prevent_vectorization); + unroll_mode, prevent_vectorization); } std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); } std::unique_ptr loop(new ForLoop( - /*prefix=*/name_, suffix, start_index, end_index, stride, - prevent_unrolling, prevent_vectorization)); + /*prefix=*/name_, suffix, start_index, end_index, stride, unroll_mode, + prevent_vectorization)); loop->Emit(ir_builder_); if (outer_loop_preheader_bb_ == nullptr) { @@ -215,23 +220,23 @@ std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, tensorflow::StringPiece suffix, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); return AddLoop(suffix, ir_builder_->getInt64(start_index), - ir_builder_->getInt64(end_index), prevent_unrolling, + ir_builder_->getInt64(end_index), unroll_mode, prevent_vectorization); } std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, int64 stride, tensorflow::StringPiece suffix, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); return AddLoop(suffix, ir_builder_->getInt64(start_index), ir_builder_->getInt64(end_index), - ir_builder_->getInt64(stride), prevent_unrolling, + ir_builder_->getInt64(stride), unroll_mode, prevent_vectorization); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h 
b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index d915f95db134918a173a9711936bb1e2f1ea0d95..4e403cd994874c27453574283c5c573c876628db 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -34,6 +34,12 @@ limitations under the License. namespace xla { namespace llvm_ir { +enum class UnrollMode { + kDefaultUnroll, + kFullyUnroll, + kNoUnroll, +}; + // A class for constructing a for-loop in LLVM IR. class ForLoop { public: @@ -69,12 +75,13 @@ class ForLoop { // LLVM IR. If non-empty, it is prepended to the name of the induction // variable value and each basic block created for the loop. // - // If `prevent_unrolling` is true then emit metadata that directs LLVM to not - // unroll the generated loop. + // `unroll_mode` specifies the desired LLVM unrolling behavior for generated + // loop. static std::unique_ptr EmitForLoop( tensorflow::StringPiece prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling = false, bool prevent_vectorization = false); + UnrollMode unroll_mode = llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // The names of the blocks follow LLVM's conventions. Control flow amongst the // blocks for the example C code looks like: @@ -128,7 +135,7 @@ class ForLoop { ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, - bool prevent_unrolling, bool prevent_vectorization); + UnrollMode unroll_mode, bool prevent_vectorization); // Emit the loop at the insert point of the builder. 
void Emit(llvm::IRBuilder<>* ir_builder); @@ -161,7 +168,7 @@ class ForLoop { llvm::BasicBlock* body_bb_; llvm::BasicBlock* exit_bb_; llvm::Value* indvar_; - bool prevent_unrolling_; + UnrollMode unroll_mode_; bool prevent_vectorization_; TF_DISALLOW_COPY_AND_ASSIGN(ForLoop); @@ -182,34 +189,34 @@ class ForLoopNest { // Adds a loop to the nest. If no loop has been added yet then emit a loop at // the current insert point of the given builder. If one or more loops have - // been added then emit loop inside the body of the last added loop. If - // prevent_unrolling is true, then metadata is emitting directing LLVM to not - // unroll this loop. - std::unique_ptr AddLoop(tensorflow::StringPiece suffix, - llvm::Value* start_index, - llvm::Value* end_index, llvm::Value* stride, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + // been added then emit loop inside the body of the last added loop. + // unroll_mode is used to emit metadata that controls LLVM unrolling. + std::unique_ptr AddLoop( + tensorflow::StringPiece suffix, llvm::Value* start_index, + llvm::Value* end_index, llvm::Value* stride, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. - std::unique_ptr AddLoop(tensorflow::StringPiece suffix, - llvm::Value* start_index, - llvm::Value* end_index, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + std::unique_ptr AddLoop( + tensorflow::StringPiece suffix, llvm::Value* start_index, + llvm::Value* end_index, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // A convenient wrapper of the other flavor of AddLoop. The given start and // end index are constant. 
- std::unique_ptr AddLoop(int64 start_index, int64 end_index, - int64 stride, tensorflow::StringPiece suffix, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + std::unique_ptr AddLoop( + int64 start_index, int64 end_index, int64 stride, + tensorflow::StringPiece suffix, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. - std::unique_ptr AddLoop(int64 start_index, int64 end_index, - tensorflow::StringPiece suffix, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + std::unique_ptr AddLoop( + int64 start_index, int64 end_index, tensorflow::StringPiece suffix, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // Add loops to iterate through the indices within the specified // shape. The returned index collects the induction variables of the diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index ec04239b4f9112134ba876fdfbb3905a3baf1f72..ff64da87e9c9acf8a9d7ff87d3b1be7a9e9106bb 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -87,18 +87,10 @@ llvm::Value* EmitCallToIntrinsic( tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice overloaded_types, llvm::IRBuilder<>* ir_builder) { - std::vector types; - for (auto type : overloaded_types) { - types.push_back(type); - } llvm::Module* module = ModuleFromIRBuilder(ir_builder); - llvm::Function* intrinsic = - llvm::Intrinsic::getDeclaration(module, intrinsic_id, types); - std::vector operands_vec; - for (auto operand : operands) { - operands_vec.push_back(operand); - } - return ir_builder->CreateCall(intrinsic, operands_vec); + llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration( + module, intrinsic_id, AsArrayRef(overloaded_types)); + 
return ir_builder->CreateCall(intrinsic, AsArrayRef(operands)); } llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, @@ -368,15 +360,52 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, return llvm::ConstantArray::get(aggregate_type, elements); } +template +llvm::Constant* GetConstantDataArray(const Literal& literal, + llvm::Module* module) { + const T* data = static_cast(literal.untyped_data()); + int64 num_elements = literal.size_bytes() / sizeof(T); + return llvm::ConstantDataArray::get(module->getContext(), + llvm::makeArrayRef(data, num_elements)); +} + } // namespace llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, llvm::Module* module) { - std::vector multi_index(ShapeUtil::Rank(literal.shape()), 0); - llvm::Constant* value = LiteralToConstant( - literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1, - &multi_index, module); - return value; + const Shape& shape = literal.shape(); + // TODO(b/29904935): We can get rid of this switch by exposing a + // ConstantDataArray factory method that takes a llvm::Type and a StringRef. + switch (shape.element_type()) { + case U64: + return GetConstantDataArray(literal, module); + case U32: + return GetConstantDataArray(literal, module); + case U8: + return GetConstantDataArray(literal, module); + case S64: + return GetConstantDataArray(literal, module); + case S32: + return GetConstantDataArray(literal, module); + case F64: + return GetConstantDataArray(literal, module); + case F32: + return GetConstantDataArray(literal, module); + case BF16: + case F16: + return GetConstantDataArray(literal, module); + case PRED: + return GetConstantDataArray(literal, module); + // TODO(b/29904935): Also use ConstantDataArray for complex numbers. 
+ case C64: { + int64 dimensions = ShapeUtil::Rank(shape); + std::vector multi_index(dimensions, 0); + return LiteralToConstant(literal, /*dimension_index=*/dimensions - 1, + &multi_index, module); + } + default: + LOG(FATAL) << "unsupported type " << shape.element_type(); + } } llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 3978acc132f34b8b195d3772ccf71d0d467984db..dc2934a34c23f8229947210cacc9863d47c2ea55 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -39,14 +39,13 @@ LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, const IrArray& target_array, llvm::IRBuilder<>* ir_builder) - : body_emitter_([=](const llvm_ir::IrArray::Index array_index) - -> ::tensorflow::Status { + : body_emitter_([=](const llvm_ir::IrArray::Index array_index) -> Status { // Convert target_element_generator to a BodyEmitter. TF_ASSIGN_OR_RETURN(llvm::Value * target_element, target_element_generator(array_index)); target_array.EmitWriteArrayElement(array_index, target_element, ir_builder); - return tensorflow::Status::OK(); + return Status::OK(); }), shape_(target_array.GetShape()), ir_builder_(ir_builder) {} @@ -84,7 +83,9 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, // Sanity check: In multi-output fusion, all shapes produced must have the // same dimensions. 
for (const IrArray& array : target_arrays) { - CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape())); + CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape())) + << ": '" << shape_.ShortDebugString() << "' does not match '" + << array.GetShape().ShortDebugString() << "'"; } } @@ -124,7 +125,7 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( return {array_index}; } -tensorflow::Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { +Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { for (const IrArray::Index& array_index : EmitIndexAndSetExitBasicBlock(loop_name)) { TF_RETURN_IF_ERROR(body_emitter_(array_index)); @@ -135,7 +136,7 @@ tensorflow::Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { if (exit_bb_ != nullptr) { ir_builder_->SetInsertPoint(exit_bb_); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index 9ff497aecd0bc964c929205c7fd410cca87d9b77..b70d28ecd3033eb26629718e50ce48f39b162273 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -38,8 +38,7 @@ using ElementGenerator = // Emits a loop for every element in the given shape. class LoopEmitter { public: - using BodyEmitter = - std::function; + using BodyEmitter = std::function; LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, llvm::IRBuilder<>* ir_builder); @@ -72,7 +71,7 @@ class LoopEmitter { tensorflow::StringPiece loop_name); // Emits a complete loop nest for every element in the given shape. - tensorflow::Status EmitLoop(tensorflow::StringPiece loop_name = ""); + Status EmitLoop(tensorflow::StringPiece loop_name = ""); protected: // An IR emitter that generates the loop body. 
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.cc b/tensorflow/compiler/xla/service/llvm_ir/ops.cc index 34899b7400464e4f4f97d301f35ed3b7b083bca1..dacc54742c0897bbd92315f1e33a484aae56bb7f 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.cc @@ -49,22 +49,41 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( for (int64 i = 0; i < rank; ++i) { IrArray::Index dim_index({ir_builder->getInt64(i)}); TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index)); + llvm::Value* output_dim_size = llvm::ConstantInt::get( + start_index[i]->getType(), output_shape.dimensions(i)); + llvm::Value* update_dim_size = llvm::ConstantInt::get( + start_index[i]->getType(), update_shape.dimensions(i)); + + // Clamp the start index so that the update region fits in the operand. + // start_index = clamp(start_index, 0, output_dim_size - update_dim_size) + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to officially document different behavior. + llvm::Value* max_bound = + ir_builder->CreateSub(output_dim_size, update_dim_size); + llvm::Value* zero = llvm::ConstantInt::get(start_index[i]->getType(), 0); + start_index[i] = ir_builder->CreateSelect( + ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SGE, zero, start_index[i]), + zero, start_index[i]); + + start_index[i] = ir_builder->CreateSelect( + ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SLE, max_bound, + start_index[i]), + max_bound, start_index[i]); } auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status { // Calculate output_index, where we'll write the value from update. For // each dimension, // - // output_index[dim] = (start_index[dim] + update_index[dim]) % dim_size. 
+ // output_index[dim] = start_index[dim] + update_index[dim] // IrArray::Index output_index(rank); for (int64 i = 0; i < rank; ++i) { - llvm::Value* dim_size = llvm::ConstantInt::get( - update_index[i]->getType(), output_shape.dimensions(i)); - llvm::Value* start_index0 = ir_builder->CreateZExtOrBitCast( + llvm::Value* start_index0 = ir_builder->CreateSExtOrBitCast( start_index[i], update_index[i]->getType()); - output_index[i] = ir_builder->CreateURem( - ir_builder->CreateAdd(start_index0, update_index[i]), dim_size); + output_index[i] = ir_builder->CreateAdd(start_index0, update_index[i]); } // Do output[output_index] = update[update_index]. diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index 3a21eda35757aa706565ee4a5286eee1acea117b..5fc08aab916e377b245b6221108956c06da70767 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -24,15 +24,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { namespace llvm_ir { -void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, - llvm::Value* on_false, llvm::IRBuilder<>* ir_builder, - llvm::Module* module) { +void EmitTupleSelect(const IrArray& select, const IrArray& pred, + llvm::Value* on_true, llvm::Value* on_false, + llvm::IRBuilder<>* ir_builder, llvm::Module* module) { CHECK(ShapeUtil::IsScalar(pred.GetShape())); llvm::LoadInst* pred_value = @@ -47,30 +46,27 @@ void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, VLOG(2) << " pred_cond: " << DumpToString(*pred_cond); for (int i = 0; i < ShapeUtil::TupleElementCount(select.GetShape()); ++i) { - std::vector element_index = {ir_builder->getInt64(0), - ir_builder->getInt64(i)}; + llvm::Value* const element_index[] = {ir_builder->getInt64(0), + ir_builder->getInt64(i)}; llvm::Value* on_true_element_address = ir_builder->CreateInBoundsGEP(on_true, element_index); llvm::Value* on_true_element = ir_builder->CreateLoad( - on_true_element_address, - tensorflow::strings::Printf("on_true_element_%d", i).c_str()); + on_true_element_address, "on_true_element_" + llvm::Twine(i)); llvm::Value* on_false_element_address = ir_builder->CreateInBoundsGEP(on_false, element_index); llvm::Value* on_false_element = ir_builder->CreateLoad( - on_false_element_address, - tensorflow::strings::Printf("on_false_element_%d", i).c_str()); + on_false_element_address, "on_false_element_" + llvm::Twine(i)); llvm::Value* output_element_address = ir_builder->CreateInBoundsGEP(select.GetBasePointer(), element_index); ir_builder->CreateStore( - ir_builder->CreateSelect( - pred_cond, on_true_element, on_false_element, - tensorflow::strings::Printf("select_output_element_%d", i).c_str()), + 
ir_builder->CreateSelect(pred_cond, on_true_element, on_false_element, + "select_output_element_" + llvm::Twine(i)), output_element_address); } } -void EmitTuple(IrArray tuple, +void EmitTuple(const IrArray& tuple, tensorflow::gtl::ArraySlice operands, llvm::IRBuilder<>* ir_builder, llvm::Module* module) { for (size_t i = 0; i < operands.size(); ++i) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h index dbf9a140068b60505f6798360438f709bfd3feba..352d34ebf839c6c2465abade7c3d3eb3b7a34506 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h @@ -59,13 +59,13 @@ namespace llvm_ir { // of the address from the corresponding element in either // tuple_on_true or tuple_on_false: // output[i] = pred ? tuple_on_true[i] : tuple_on_false[i] -void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, - llvm::Value* on_false, llvm::IRBuilder<>* ir_builder, - llvm::Module* module); +void EmitTupleSelect(const IrArray& select, const IrArray& pred, + llvm::Value* on_true, llvm::Value* on_false, + llvm::IRBuilder<>* ir_builder, llvm::Module* module); // A tuple is an array of pointers, one for each operand. Each pointer points to // the output buffer of its corresponding operand. -void EmitTuple(IrArray tuple, +void EmitTuple(const IrArray& tuple, tensorflow::gtl::ArraySlice operands, llvm::IRBuilder<>* ir_builder, llvm::Module* module); diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 0fa4061738612df76c72a18a9353f16bf6a42677..296d04d4362b12fdc39798a016ca9e8795e02586 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -24,15 +24,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -110,6 +107,11 @@ ExecutionOptions CreateExecutionOptions( ->set_xla_dump_optimized_hlo_proto_to( build_options.dump_optimized_hlo_proto_to().value()); } + if (build_options.dump_unoptimized_hlo_proto_to().has_value()) { + execution_options.mutable_debug_options() + ->set_xla_dump_unoptimized_hlo_proto_to( + build_options.dump_unoptimized_hlo_proto_to().value()); + } if (build_options.dump_per_pass_hlo_proto_to().has_value()) { execution_options.mutable_debug_options() ->set_xla_dump_per_pass_hlo_proto_to( @@ -124,75 +126,17 @@ ExecutionOptions CreateExecutionOptions( LayoutUtil::SetToDefaultLayout( execution_options.mutable_shape_with_output_layout()); } - return execution_options; -} - -} // namespace - -StatusOr> LocalService::CompileExecutable( - const ComputationHandle& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& build_options) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(computation)); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - TF_ASSIGN_OR_RETURN( 
- std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - // Validate incoming layouts. - if (argument_layouts.size() != program_shape->parameters_size()) { - return InvalidArgument( - "Invalid number of arguments for computation: expected %d, got %zu.", - program_shape->parameters_size(), argument_layouts.size()); + for (const std::string& disabled_pass : build_options.disabled_hlo_passes()) { + execution_options.mutable_debug_options()->add_xla_disable_hlo_passes( + disabled_pass); } - for (int i = 0; i < argument_layouts.size(); ++i) { - const Shape& argument_shape = *argument_layouts[i]; - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); - if (!ShapeUtil::Compatible(argument_shape, program_shape->parameters(i))) { - tensorflow::gtl::optional metadata = - user_computation->ParameterMetadata(i); - auto metadata_string = [&metadata]() -> string { - if (!metadata.has_value()) { - return ""; - } - CHECK(metadata.value() != nullptr); - const OpMetadata& m = *metadata.value(); - if (!m.source_file().empty()) { - return tensorflow::strings::Printf( - " (%s:%d)", m.source_file().c_str(), m.source_line()); - } - return ""; - }; - return InvalidArgument( - "Invalid argument shape for argument %d%s, expected %s, got %s.", i, - metadata_string().c_str(), - ShapeUtil::HumanString(program_shape->parameters(i)).c_str(), - ShapeUtil::HumanString(argument_shape).c_str()); - } - } - if (build_options.result_layout() != nullptr) { - TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout( - *build_options.result_layout(), program_shape->result())); - } - - ExecutionOptions execution_options = - CreateExecutionOptions(build_options, program_shape.get()); - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, argument_layouts, - &execution_options, user_computation)); - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - execute_backend_->stream_executor(build_options.device_ordinal())); - - 
return BuildExecutable(versioned_handle, std::move(module_config), - execute_backend_.get(), executor, - build_options.device_allocator()); + return execution_options; } +} // namespace + StatusOr> LocalService::CompileExecutable( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -260,4 +204,15 @@ StatusOr LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) { /*computation_count=*/1); } +StatusOr LocalService::GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number) { + TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data)); + if (replica_number >= buffers.size()) { + return InvalidArgument( + "replica_number %d out of range; must be less than num_replicas = %zu.", + replica_number, buffers.size()); + } + return buffers[replica_number]; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 06567cabd6eb28aae53881613cd6beb78e25e222..39d6734c3fc06df6832cf67edddbc7c14c815cd1 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -41,23 +41,11 @@ class LocalService : public Service { static StatusOr> NewService( const ServiceOptions& options); - // Builds an Executable with the given argument layouts and options. If - // result_layout is non-null, then the executable is compiled to produce a - // result of the given layout. If device_allocator is non-null, then the - // compiler may use it to allocate temp space on the device. The compiler is - // responsible for freeing any memory it allocates this way. - StatusOr> CompileExecutable( - const ComputationHandle& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& options); - // Builds an Executable with the given XlaComputation, argument layouts and // options. 
If result_layout is non-null, then the executable is compiled to // produce a result of the given layout. If device_allocator is non-null, // then the compiler may use it to allocate temp space on the device. The // compiler is responsible for freeing any memory it allocates this way. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> CompileExecutable( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -70,6 +58,11 @@ class LocalService : public Service { // the "easy" case where a single replica is a single device. StatusOr ReplicaNumberToDeviceOrdinal(int replica_number); + // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid + // as long as the handle is valid. + StatusOr GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number); + private: explicit LocalService(const ServiceOptions& options, std::unique_ptr backend); diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index 6aca6ba38572c5311797fbb91acbbcd6610a3410..f410921b4b5337192bdeae5924631d9c06b7d5a5 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -125,6 +125,12 @@ Status LogicalBufferAnalysis::HandleBitcast(HloInstruction*) { return Status::OK(); } +Status LogicalBufferAnalysis::HandleDomain(HloInstruction*) { + // A kDomain instruction aliases its operand. That is, the buffer of its + // result *is* the buffer of its operand. + return Status::OK(); +} + Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction*) { // RecvDone doesn't create a new buffer but rather aliases its input (Recv) // tuple element at {0} to its output. 
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h index f4c63dd86b4d8a6f598d46047012e4e5bc7b3d7e..b5ef3967875a58b35631d5f69c210f5cbcd91250 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h @@ -59,6 +59,7 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { Status HandleTuple(HloInstruction* tuple) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleDomain(HloInstruction* domain) override; Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc new file mode 100644 index 0000000000000000000000000000000000000000..29f787b86b9cbb6f80d048b46b78bdad8074f488 --- /dev/null +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -0,0 +1,342 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/multi_output_fusion.h" + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +StatusOr MultiOutputFusion::Run(HloModule* module) { + bool changed = false; + + for (auto* computation : module->MakeNonfusionComputations()) { + computation_ = computation; + reachability_ = computation_->ComputeReachability(); + candidates_.clear(); + candidates_index_.clear(); + all_fusion_candidates_.clear(); + + int64 index = 0; + for (auto it : computation_->MakeInstructionPostOrder()) { + candidates_.emplace_back(it); + InsertOrDie(&candidates_index_, it, index++); + } + + // Create the initial candidate list for each Node. + for (auto& node : candidates_) { + HloInstruction* instruction = node.hlo; + int64 instruction_id = get_candidate_id(instruction); + FusionCandidate& instr_node = candidates_[instruction_id]; + if (!IsFusible(instruction)) { + continue; + } + all_fusion_candidates_.push_back(instruction); + + std::vector candidates; + tensorflow::gtl::FlatSet candidates_set; + VLOG(10) << "Looking at instruction: " << instruction->name(); + for (auto operand : instruction->operands()) { + // Filter out the non-interesting instructions -- they + // will not generate the savings. 
+ if (!IsProfitableOperand(operand)) { + VLOG(10) << "Operand not profitable: " << operand->name(); + continue; + } + VLOG(10) << "Operand profitable: " << operand->name(); + for (auto user : operand->users()) { + VLOG(10) << "User: " << user->name(); + if (user == instruction || !IsFusible(user)) { + VLOG(10) << "User is not fusible, or is the instruction itself: " + << user->name(); + continue; + } + int64 user_id = get_candidate_id(user); + if (is_connected(instruction, user)) { + VLOG(10) << "User is connected: " << user->name(); + continue; + } + if (instruction_id < user_id && + user->opcode() == HloOpcode::kFusion) { + VLOG(10) << "User ID for user: " << user->name() << " is " + << user_id << " which is higher than " << instruction_id; + continue; + } + if (!LegalToFuse(instruction, user)) { + VLOG(10) << "User not legal to fuse: " << user->name(); + continue; + } + if (candidates_set.insert(user).second) { + VLOG(10) << "User added to candidate list: " << user->name(); + candidates.push_back(user); + } + } + } + + // Iterate over candidates rather than candidates_set to avoid + // nondeterminism. + for (auto candidate : candidates) { + int64 profit = GetProfit(instruction, candidate); + if (profit > 0) { + FusionCandidate& candidate_node = + candidates_[get_candidate_id(candidate)]; + instr_node.fusibles.emplace_back(candidate, profit); + candidate_node.fusibles.emplace_back(instruction, profit); + worklist_.emplace(instruction, candidate, profit); + } + } + } + if (Perform()) { + changed = true; + } + } + return changed; +} + +HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1, + HloInstruction* instr2) { + HloInstruction* remaining = instr1; + HloInstruction* fused = instr2; + // Make sure that if only one of the instructions is a fusion, or if only one + // of the instructions is a multi-output fusion, it's what will be fused into. + // + // An invariant is that no bitcast nodes will show up in the middle of a + // fusion node. 
This invariant must hold in order for us to lower it. Given + // that, we require that during multi-output fusion, a fusion node ending with + // bitcast to preserve its structure as a nested fusion instead being + // merged and flattened. + if (fused->opcode() == HloOpcode::kFusion && + fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) { + std::swap(remaining, fused); + } + if (fused->IsMultiOutputFusion()) { + std::swap(remaining, fused); + } + + if (fused->opcode() == HloOpcode::kFusion && + fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) { + remaining->MergeFusionInstructionIntoMultiOutput(fused); + } else { + if (remaining->opcode() == HloOpcode::kFusion && + remaining->fused_expression_root()->opcode() == HloOpcode::kBitcast) { + auto parent_computation = remaining->parent(); + // Create a nested fusion node. + auto remaining_nested_fused = + parent_computation->AddInstruction(HloInstruction::CreateFusion( + remaining->shape(), HloInstruction::FusionKind::kLoop, + remaining)); + TF_CHECK_OK(parent_computation->ReplaceInstruction( + remaining, remaining_nested_fused)); + remaining = remaining_nested_fused; + } + remaining->FuseInstructionIntoMultiOutput(fused); + } + + return remaining; +} + +void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) { + HloInstruction* fusion = instr1; + HloInstruction* fused = instr2; + if (is_fused(instr1)) { + fusion = instr2; + fused = instr1; + } + + // Insert the newly created instruction (if any), to candidates_. + for (auto use : fusion->users()) { + if (candidates_index_.find(use) == candidates_index_.end()) { + int64 index = candidates_.size(); + candidates_.emplace_back(use); + InsertOrDie(&candidates_index_, use, index++); + } + } + FusionCandidate& fusion_node = candidates_[get_candidate_id(fusion)]; + FusionCandidate& fused_node = candidates_[get_candidate_id(fused)]; + + // Update the reachability graph. 
+ UpdateReachability(fusion, fused, all_fusion_candidates_, + [this](HloInstruction* instr) { return is_fused(instr); }); + + // Update the fusible list for fusion. Variable new_fusibles keeps + // track of the new or changed entries. + std::vector> new_fusibles; + tensorflow::gtl::FlatSet in_list; + auto it = fusion_node.fusibles.begin(); + while (it != fusion_node.fusibles.end()) { + HloInstruction* instr = it->first; + if (is_fused(instr) || is_connected(fusion, instr)) { + it = fusion_node.fusibles.erase(it); + continue; + } + in_list.insert(instr); + int64 profit = GetProfit(instr, fusion); + if (profit > it->second) { + it->second = profit; + new_fusibles.emplace_back(instr, profit); + } + ++it; + } + + // Fused_node has been fused into fusion_node. Take the fusion candidates + // (fusibles) from fused_nodes and add them to the fusion_node's. Filter + // out those fusibles that no longer valid (or already in the list). + for (const auto& it : fused_node.fusibles) { + HloInstruction* instr = it.first; + if (instr == fusion || is_fused(instr) || is_connected(fusion, instr)) { + continue; + } + if (in_list.count(instr) > 0) { + continue; + } + int64 profit = GetProfit(instr, fusion); + fusion_node.fusibles.emplace_back(instr, profit); + new_fusibles.emplace_back(instr, profit); + } + fused_node.fusibles.clear(); + + // Update the worklist_. + for (auto it : new_fusibles) { + worklist_.emplace(fusion, it.first, it.second); + } +} + +bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1, + HloInstruction* instr2) { + if (instr1 == instr2) { + return false; + } + if (instr1->opcode() != HloOpcode::kFusion) { + return false; + } + + // Fusing nodes with 0 user makes no sense and the rest of the implementation + // doesn't support it either. + if (instr1->user_count() == 0 || instr2->user_count() == 0) { + return false; + } + + // Check if the users of multioutput fusion is not a get-tuple-element. 
+ // If this is the case, we bail out because the transformation assumes + // the users are get-tuple-element. + auto multioutput_user_is_not_gte = [](HloInstruction* instr) { + if (!instr->IsMultiOutputFusion()) { + return false; + } + for (auto user : instr->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + return true; + } + } + return false; + }; + if (multioutput_user_is_not_gte(instr1) || + multioutput_user_is_not_gte(instr2)) { + return false; + } + + if (is_connected(instr1, instr2)) { + return false; + } + if (!ShapesCompatibleForFusion(instr1, instr2)) { + return false; + } + + return true; +} + +void MultiOutputFusion::UpdateReachability( + HloInstruction* instr1, HloInstruction* instr2, + tensorflow::gtl::ArraySlice instrs_to_update, + const std::function& skip) { + for (auto instr : instrs_to_update) { + if (skip != nullptr && skip(instr)) { + continue; + } + if (reachability_->IsReachable(instr2, instr) && + reachability_->IsReachable(instr1, instr)) { + // If a candidate was already reachable by both, no update needed. + continue; + } + if (reachability_->IsReachable(instr2, instr)) { + reachability_->FastSetReachabilityToUnion({instr, instr1}, instr); + } + if (reachability_->IsReachable(instr1, instr)) { + reachability_->FastSetReachabilityToUnion({instr, instr2}, instr); + } + } +} + +bool MultiOutputFusion::Perform() { + int changed = false; + // Pick the top candidate from queue and try to merge. 
+ while (!worklist_.empty()) { + if (fuel_ <= 0) { + VLOG(2) << "No fusing: run out of fuel."; + break; + } + ToBeFused candidate = worklist_.top(); + worklist_.pop(); + + HloInstruction* instr1 = candidate.instr1; + HloInstruction* instr2 = candidate.instr2; + + if (is_fused(instr1) || is_fused(instr2)) { + continue; + } + + VLOG(1) << "Considering candidate profit_score=" << candidate.score + << "\n\t\tinstr1 = " << instr1->ToString() + << "\n\t\tinstr2 = " << instr2->ToString(); + + if (LegalToFuse(instr1, instr2)) { + VLOG(1) << "Fuse!"; + VLOG(2) << "Before multi_output_fusion:"; + VLOG(2) << "instr1: " << instr1->ToString(); + VLOG(2) << "\n" + << instr1->fused_instructions_computation()->ToString( + HloPrintOptions().set_indent_amount(1)); + VLOG(2) << "instr2: " << instr2->ToString(); + if (instr2->opcode() == HloOpcode::kFusion) { + VLOG(2) << "\n" + << instr2->fused_instructions_computation()->ToString( + HloPrintOptions().set_indent_amount(1)); + } + HloInstruction* ret = Fuse(instr1, instr2); + set_is_fused(ret == instr1 ? instr2 : instr1); + Update(instr1, instr2); + changed = true; + VLOG(2) << "After fusion, \t this: " << ret->name() << "\n" + << ret->fused_instructions_computation()->ToString( + HloPrintOptions().set_indent_amount(1)); + auto users = ret->users(); + --fuel_; + } + } + if (DoProducerConsumerMultiOutputFusion(computation_)) { + changed = true; + } + return changed; +} + +bool MultiOutputFusion::DoProducerConsumerMultiOutputFusion( + HloComputation* /*computation*/) { + return false; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h new file mode 100644 index 0000000000000000000000000000000000000000..cfdf83cfe856a7c3b05f51129446cd4e1055a8d6 --- /dev/null +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -0,0 +1,160 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace xla { + +// This class implements the fusing of sibling fusion instructions that share +// common operands. +// It constructs the following associated data structures. +// (1) candidates_: stores the instruction and the set of instructions it can +// fuse to. +// (2) candidates_index_: maps instruction to id. +// (3) reachability_: reachability map in this computation. +// (4) all_fusion_candidates_: the vector of candidate instructions. +// (5) worklist_: a priority queue that contains pairs of instructions to be +// fused and their fusion profit scores. +// +// Function Perform() applies the optimization. It picks the most profitable +// pair in the worklist_, checks whether it is legal to fuse, and fuses the pair. +// After fusion, it updates the associated structures such as reachability_, +// candidates_ and worklist_. +// Note that the reachability map is updated based on the original computation. 
+// This works because the reachability is monotonically increasing with +// instruction fusion. +class MultiOutputFusion : public HloPassInterface { + public: + MultiOutputFusion(int64 fuel) : fuel_(fuel) {} + + tensorflow::StringPiece name() const override { + return "multi_output_fusion"; + } + + // Run multi-output fusion on the given module. Returns whether the module + // was changed. + StatusOr Run(HloModule* module) override; + + protected: + // Main entry for the optimization. Returns true if the optimization happens. + bool Perform(); + + // Test if instr1 and instr2 have the compatible shapes that can be legally + // fused. + virtual bool ShapesCompatibleForFusion(HloInstruction* instr1, + HloInstruction* instr2) = 0; + + // Whether the instruction is a candidate for fusion. + virtual bool IsFusible(HloInstruction* instr) = 0; + + // This function estimates the savings by merging instr1 and instr2 into one + // multi-output fusion instruction. + virtual int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) = 0; + + // Whether fusing the instruction can reduce cost. + virtual bool IsProfitableOperand(HloInstruction* instr) = 0; + + // Test if it's legal to fuse instr1 and instr2 into one fusion instruction. + virtual bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2); + + // Update the reachability map after fusing instr1 and instr2. + void UpdateReachability( + HloInstruction* instr1, HloInstruction* instr2, + tensorflow::gtl::ArraySlice instrs_to_update, + const std::function& skip = nullptr); + + // Hook for multi-output fusion along producer-consumer edges. + // Returns whether any instructions were fused. + // + // TODO(b/80420762): Perform producer-consumer multi-output fusion in + // InstructionFusion instead. + virtual bool DoProducerConsumerMultiOutputFusion(HloComputation* computation); + + private: + // Fuse HloInstruction instr1 and instr2 and return the fused instruction. 
+ // The other instruction is removed from its parent computation. + HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2); + + // Update the internal data structures after instr1 and instr2 are fused into + // one fusion instruction. + void Update(HloInstruction* instr1, HloInstruction* instr2); + + // Optimization fuel is a compiler debugging technique that makes an + // optimization pass stop what it is doing after having made N changes to the + // program, where N is the fuel. By varying N, this can be used to find the + // first single change that makes a test fail. + int64 fuel_; + + // Computation for the pass. + HloComputation* computation_; + + // An internal data structure for each instruction in current computation. + // When an instruction is removed, member 'hlo' is set to nullptr. + struct FusionCandidate { + HloInstruction* hlo; + std::list> fusibles; + explicit FusionCandidate(HloInstruction* hlo) : hlo(hlo) {} + }; + std::vector candidates_; + + // A map that maps an instruction to the index_. + tensorflow::gtl::FlatMap candidates_index_; + + // The reachability map of current computation. + std::unique_ptr reachability_; + + // This stores all the candidate instructions in current computation. + std::vector all_fusion_candidates_; + + // The pair of candidates to be fused and the profit score. 
+ struct ToBeFused { + HloInstruction* instr1; + HloInstruction* instr2; + int64 score; + ToBeFused(HloInstruction* instr1, HloInstruction* instr2, int64 score) + : instr1(instr1), instr2(instr2), score(score) {} + bool operator<(const ToBeFused& rhs) const { return score < rhs.score; } + }; + std::priority_queue worklist_; + + int64 get_candidate_id(HloInstruction* instr) { + return FindOrDie(candidates_index_, instr); + } + + bool is_fused(HloInstruction* instr) { + return candidates_[get_candidate_id(instr)].hlo == nullptr; + } + + void set_is_fused(HloInstruction* instr) { + candidates_[get_candidate_id(instr)].hlo = nullptr; + } + + bool is_connected(HloInstruction* instr1, HloInstruction* instr2) { + return reachability_->IsConnected(instr1, instr2); + } +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_ diff --git a/tensorflow/compiler/xla/service/owning_device_memory.cc b/tensorflow/compiler/xla/service/owning_device_memory.cc new file mode 100644 index 0000000000000000000000000000000000000000..c115bc097f3b1dd810654745b835a977955718c3 --- /dev/null +++ b/tensorflow/compiler/xla/service/owning_device_memory.cc @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/owning_device_memory.h" + +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" + +namespace xla { + +void OwningDeviceMemory::Free() { + CHECK(allocator_ != nullptr) + << "Can't call Free() on an inactive (i.e. moved from, Forget()'ten, " + "or Free()'ed) instance."; + auto status = allocator_->Deallocate(device_ordinal_, mem_); + if (!status.ok()) { + LOG(WARNING) << "Deallocating buffer " << mem_.opaque() << " failed."; + } + + allocator_ = nullptr; + mem_ = se::DeviceMemoryBase(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/owning_device_memory.h b/tensorflow/compiler/xla/service/owning_device_memory.h new file mode 100644 index 0000000000000000000000000000000000000000..9cf071f0d9d09dfbf74b15e73caaf542714ec8d5 --- /dev/null +++ b/tensorflow/compiler/xla/service/owning_device_memory.h @@ -0,0 +1,131 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_ + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// Break circular dependency between this file and device_memory_allocator.h. +class DeviceMemoryAllocator; + +// Owning pointer for memory on a device. +// +// OwningDeviceMemory is an owning pointer like std::unique_ptr, but it can +// point to memory that resides on a "device" (e.g. a GPU). When an +// OwningDeviceMemory goes out of scope, it frees the memory it owns. +// +// We say that an instance of OwningDeviceMemory is "active" if it currently +// owns a (possibly empty) slice of memory on the device. Moving, Forget()'ing, +// Free()'ing, and other actions can deactivate an active object. +// +// Note that we can't simply use stream_executor::ScopedDeviceMemory instead of +// OwningDeviceMemory, because ScopedDeviceMemory frees its pointer via a +// StreamExecutor. This class needs to free via an xla::DeviceMemoryAllocator. 
+class OwningDeviceMemory { + public: + OwningDeviceMemory() : device_ordinal_(-1), allocator_(nullptr) {} + + explicit OwningDeviceMemory(se::DeviceMemoryBase mem, int device_ordinal, + DeviceMemoryAllocator* allocator) + : mem_(mem), device_ordinal_(device_ordinal), allocator_(allocator) { + CHECK(allocator != nullptr) << "allocator cannot be null."; + } + + OwningDeviceMemory(OwningDeviceMemory&& other) + : mem_(other.mem_), + device_ordinal_(other.device_ordinal_), + allocator_(other.allocator_) { + other.mem_ = se::DeviceMemoryBase(); + other.allocator_ = nullptr; + } + + OwningDeviceMemory& operator=(OwningDeviceMemory&& other) { + if (allocator_ != nullptr) { + Free(); + } + mem_ = other.mem_; + device_ordinal_ = other.device_ordinal_; + allocator_ = other.allocator_; + + other.mem_ = se::DeviceMemoryBase(); + other.allocator_ = nullptr; + return *this; + } + + // Deactivates this instance if it's active. Nop if it's not active. + OwningDeviceMemory& operator=(std::nullptr_t) { + if (allocator_ != nullptr) { + Free(); + } + return *this; + } + + ~OwningDeviceMemory() { + if (allocator_ != nullptr) { + Free(); + } + } + + // The returned allocator is nonnull iff this object is active. + DeviceMemoryAllocator* allocator() const { return allocator_; } + + int device_ordinal() const { return device_ordinal_; } + + // Gets the device memory pointer. + const void* opaque() const { return mem_.opaque(); } + void* opaque() { return mem_.opaque(); } + + uint64 size() const { return mem_.size(); } + + // Determines whether this wraps a null pointer. + // + // !is_null() is sufficient but not necessary to imply `this` is active. + bool is_null() const { return mem_.is_null(); } + + se::DeviceMemoryBase AsDeviceMemoryBase() { + return se::DeviceMemoryBase(opaque(), size(), /*is_sub_buffer=*/false); + } + + // Returns the wrapped DeviceMemoryBase without freeing it, and deactivates + // this object. Precondition: `this` is active. 
+ TF_MUST_USE_RESULT se::DeviceMemoryBase Forget() { + CHECK(allocator_ != nullptr) + << "Can't call Forget() on an inactive (i.e. moved from, Forget()'ten, " + "or Free()'ed) instance."; + allocator_ = nullptr; + se::DeviceMemoryBase mem(mem_); + mem_ = se::DeviceMemoryBase(); + return mem; + } + + // Frees the wrapped DeviceMemoryBase and deactivates this object. + // Precondition: `this` is active. + void Free(); + + private: + se::DeviceMemoryBase mem_; + int device_ordinal_; + DeviceMemoryAllocator* allocator_; // Null if this object is inactive. +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_ diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index d3bc47e61e0e75fa2ef181988700f88cec9c1d76..2515222cf2db3d9699c85c13f4fe72b3488fa217 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -204,7 +204,7 @@ class LayoutPattern { // Modifies the pattern to match only if the layout equals the given proto. // The layout must outlive the returned pattern. constexpr LayoutPattern> EqualTo( - const Layout* layout) const { + const ::xla::Layout* layout) const { return LayoutPattern>( LayoutPatternEqualImpl(impl_, layout), matched_layout_); } diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index 204e8c99209fa95adb868a676bb9e5144fed432c..fef3c132b0f3467a01b02f2be88b419459179277 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -16,7 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -29,7 +29,7 @@ TEST(PatternMatcherTest, AddOp) { ROOT %two_plus_two = f32[] add(f32[] %two, f32[] %two) } )"; - TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr)); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); const HloInstruction* matched_inst; HloInstruction* matched_operand; @@ -182,7 +182,7 @@ TEST(PatternMatcherTest, FusionKind) { p0 = f32[] parameter(0) ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation })"; - TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr)); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); EXPECT_TRUE(Match( diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 495f8801ba82ecbcf9f6e5db5507ef8785c752d6..d01c35b99231310692f85d0f9fbf4f2c3709d44c 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service/source_map_util.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_layout.h" @@ -62,37 +61,8 @@ namespace xla { namespace { -// Records the arguments used to invoke a computation in a SessionModule -// proto. 
-tensorflow::Status RecordArguments( - const tensorflow::gtl::ArraySlice arguments, - se::StreamExecutor* executor, TransferManager* transfer_manager, - SessionModule* module) { - module->clear_arguments(); - for (const ShapedBuffer* argument : arguments) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr literal, - transfer_manager->TransferLiteralFromDevice(executor, *argument)); - *module->add_arguments() = literal->ToProto(); - } - return tensorflow::Status::OK(); -} - -// Records the result of a computation in a SessionModule proto. -tensorflow::Status RecordResult(const ShapedBuffer& result, - se::StreamExecutor* executor, - TransferManager* transfer_manager, - SessionModule* module) { - module->clear_result(); - TF_ASSIGN_OR_RETURN( - std::unique_ptr literal, - transfer_manager->TransferLiteralFromDevice(executor, result)); - *module->mutable_result() = literal->ToProto(); - return tensorflow::Status::OK(); -} - // Records the arguments used to invoke a computation in an HloSnapshot proto. -tensorflow::Status RecordArguments( +Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, se::StreamExecutor* executor, TransferManager* transfer_manager, HloSnapshot* module) { @@ -103,20 +73,18 @@ tensorflow::Status RecordArguments( transfer_manager->TransferLiteralFromDevice(executor, *argument)); *module->add_arguments() = literal->ToProto(); } - return tensorflow::Status::OK(); + return Status::OK(); } // Records the result of a computation in a HloSnapshot proto. 
-tensorflow::Status RecordResult(const ShapedBuffer& result, - se::StreamExecutor* executor, - TransferManager* transfer_manager, - HloSnapshot* module) { +Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor, + TransferManager* transfer_manager, HloSnapshot* module) { module->clear_result(); TF_ASSIGN_OR_RETURN( std::unique_ptr literal, transfer_manager->TransferLiteralFromDevice(executor, result)); *module->mutable_result() = literal->ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace @@ -199,35 +167,20 @@ Service::Service(const ServiceOptions& options, } } -tensorflow::Status Service::Computation(const ComputationRequest* arg, - ComputationResponse* result) { - if (arg->name().empty()) { - return InvalidArgument("computation request needs a name"); - } - - *result->mutable_computation() = - computation_tracker_.NewComputation(arg->name()); - VLOG(1) << Printf("Created new computation %s on service %p, name %s", - result->computation().ShortDebugString().c_str(), this, - arg->name().c_str()); - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) { +Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) { *result->mutable_channel() = channel_tracker_.NewChannel(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) { +Status Service::Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) { return allocation_tracker_.Unregister(arg->data()); } // Deconstructs a previously-allocated global handle. 
-tensorflow::Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result) { +Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) { TF_ASSIGN_OR_RETURN( std::vector elements, allocation_tracker_.DeconstructTuple(arg->tuple_handle())); @@ -235,11 +188,11 @@ tensorflow::Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, for (auto& element : elements) { *result->add_element_handles() = element; } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ValidateResultShapeWithLayout( - const Shape& shape_with_layout, const Shape& result_shape) const { +Status Service::ValidateResultShapeWithLayout(const Shape& shape_with_layout, + const Shape& result_shape) const { if (!ShapeUtil::Compatible(shape_with_layout, result_shape)) { return InvalidArgument( "Shape used to set computation result layout %s is not compatible " @@ -293,8 +246,7 @@ Service::ResolveAndValidateArguments( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, - const ExecutionOptions* execution_options, - const UserComputation* user_computation) { + const ExecutionOptions* execution_options) { auto config = MakeUnique(program_shape); ComputationLayout* host_computation_layout = config->mutable_host_entry_computation_layout(); @@ -310,17 +262,9 @@ StatusOr> Service::CreateModuleConfig( // ProgramShape. 
if (!ShapeUtil::Compatible(*argument_shapes[i], program_shape.parameters(i))) { - if (user_computation == nullptr) { - return InvalidArgument( - "Argument does not match shape of computation parameter %d: want " - "%s, got %s", - i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(*argument_shapes[i]).c_str()); - } - return InvalidParameterArgument( - *user_computation->ParameterMetadata(i).value(), - "Argument does not match shape of computation parameter %d: want %s, " - "got %s", + return InvalidArgument( + "Argument does not match shape of computation parameter %d: want " + "%s, got %s", i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), ShapeUtil::HumanString(*argument_shapes[i]).c_str()); } @@ -345,6 +289,9 @@ StatusOr> Service::CreateModuleConfig( // If the result layout is not set, then choose the default. // TODO(b/29118294): Allow the compiler to choose a better layout in this // case. + // TODO(b/78356948): We are forcing the default layout here. We should fix + // clients which expect a default layout, to be explicit about it, by + // passing the proper ExecutionOptions with shape_with_output_layout set. 
host_computation_layout->mutable_result_layout()->SetToDefaultLayout(); device_computation_layout->mutable_result_layout()->SetToDefaultLayout(); } @@ -368,76 +315,12 @@ StatusOr> Service::CreateModuleConfig( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options, - const UserComputation* user_computation) { + const ExecutionOptions& execution_options) { std::vector argument_shapes; for (const auto* arg : arguments) { argument_shapes.push_back(&arg->on_host_shape()); } - return CreateModuleConfig(program_shape, argument_shapes, &execution_options, - user_computation); -} - -StatusOr>> Service::BuildExecutables( - std::vector versioned_handles, - std::vector> module_configs, - Backend* backend, std::vector> executors, - DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf("BuildExecutable on service %p", this); - - // Dump computation proto state if flag is set. - std::vector> session_modules; - for (int64 i = 0; i < versioned_handles.size(); ++i) { - const string& directory_path = - module_configs[i]->debug_options().xla_dump_computations_to(); - const string& other_directory_path = - module_configs[i]->debug_options().xla_dump_executions_to(); - if (directory_path.empty() && other_directory_path.empty()) { - continue; - } - TF_ASSIGN_OR_RETURN( - std::unique_ptr session_module, - computation_tracker_.SnapshotComputation(versioned_handles[i].handle)); - if (!directory_path.empty()) { - string filename = Printf("computation_%lld__%s__version_%lld", - versioned_handles[i].handle.handle(), - session_module->entry().name().c_str(), - versioned_handles[i].version); - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, - *session_module)); - session_modules.push_back(std::move(session_module)); - } - } - - VLOG(1) << "Computation handles:"; - for (const VersionedComputationHandle& versioned_handle : versioned_handles) { - VLOG(1) << 
versioned_handle; - } - - CHECK_EQ(versioned_handles.size(), module_configs.size()); - std::vector> modules; - for (int64 i = 0; i < versioned_handles.size(); ++i) { - const VersionedComputationHandle& versioned_handle = versioned_handles[i]; - const HloModuleConfig& config = *module_configs[i]; - TF_ASSIGN_OR_RETURN(auto module, - computation_tracker_.BuildHloModule( - versioned_handle, config, - /*include_unreachable_instructions=*/true)); - modules.push_back(std::move(module)); - } - - TF_ASSIGN_OR_RETURN( - std::vector> executables, - backend->compiler()->Compile(std::move(modules), std::move(executors), - device_allocator)); - - for (size_t i = 0; i < versioned_handles.size(); ++i) { - if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) { - executables[i]->set_session_module(std::move(session_modules[i])); - } - } - - return std::move(executables); + return CreateModuleConfig(program_shape, argument_shapes, &execution_options); } StatusOr>> Service::BuildExecutables( @@ -511,99 +394,7 @@ Status Service::ValidateEntryComputationLayout(HloModule* module) { module->device_entry_computation_layout().result_shape(), execute_backend_->transfer_manager()->HostShapeToDeviceShape( module->host_entry_computation_layout().result_shape()))); - return tensorflow::Status::OK(); -} - -StatusOr> Service::BuildExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf("BuildExecutable on service %p with handle %s", this, - versioned_handle.ToString().c_str()); - - // Dump computation proto state if flag is set. 
- std::unique_ptr session_module; - const string& directory_path = - module_config->debug_options().xla_dump_computations_to(); - const string& other_directory_path = - module_config->debug_options().xla_dump_executions_to(); - if (!directory_path.empty() || !other_directory_path.empty()) { - TF_ASSIGN_OR_RETURN( - session_module, - computation_tracker_.SnapshotComputation(versioned_handle.handle)); - if (!directory_path.empty()) { - string filename = Printf("computation_%lld__%s__version_%lld", - versioned_handle.handle.handle(), - session_module->entry().name().c_str(), - versioned_handle.version); - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, - *session_module)); - } - } - - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, *module_config, - /*include_unreachable_instructions=*/ - true)); - - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); - - TF_ASSIGN_OR_RETURN( - module, backend->compiler()->RunHloPasses(std::move(module), executor, - device_allocator)); - // Check that on-host and on-device shapes are consistent. - TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get())); - - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - backend->compiler()->RunBackend( - std::move(module), executor, device_allocator)); - - if (!other_directory_path.empty()) { - executable->set_session_module(std::move(session_module)); - } - - return std::move(executable); -} - -StatusOr> Service::BuildAndCacheExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, ExecutionProfile* profile, - DeviceMemoryAllocator* device_allocator) { - std::shared_ptr executable = - compilation_cache_.LookUp(versioned_handle, *module_config); - - if (executable != nullptr) { - // Executable found in the computation cache. 
- if (profile != nullptr) { - profile->set_compilation_cache_hit(true); - } - return executable; - } - - uint64 start_micros = - // Avoid reading the clock if we don't want timing info - (profile != nullptr) ? tensorflow::Env::Default()->NowMicros() : 0; - - // Take a copy of the module config, as compilation introduces layouts where - // layouts were optional before. - HloModuleConfig original_module_config = *module_config; - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable_unique_ptr, - BuildExecutable(versioned_handle, std::move(module_config), backend, - executor, device_allocator)); - - if (profile != nullptr) { - uint64 end_micros = tensorflow::Env::Default()->NowMicros(); - uint64 milliseconds = (end_micros - start_micros) / 1000; - profile->set_compilation_cache_hit(false); - profile->set_compile_time_ms(milliseconds); - } - - // Insert executable into the cache. - return compilation_cache_.Insert(std::move(executable_unique_ptr), - original_module_config); + return Status::OK(); } StatusOr> @@ -626,9 +417,16 @@ Service::ExecuteParallelAndRegisterResult( // profiled. std::map index_to_profiled_streams; - TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, - backend->computation_placer()->AssignDevices( - options_.number_of_replicas(), executables.size())); + // Build DeviceAssignment for all cores based on the provided device handles. + DeviceAssignment device_assignment(options_.number_of_replicas(), + executables.size()); + for (int64 i = 0; i < executables.size(); i++) { + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i])); + CHECK_EQ(replicas.size(), arguments[i].size()); + for (int64 replica = 0; replica < replicas.size(); ++replica) { + device_assignment(replica, i) = replicas[replica]->device_ordinal(); + } + } for (int64 i = 0; i < executables.size(); i++) { // Stream executors for the replicas of the current computation. 
@@ -801,13 +599,6 @@ StatusOr Service::ExecuteAndRegisterResult( result_tag); } -tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - return computation->SetReturnValue(arg->operand()); -} - StatusOr> Service::GetExecutors( const ExecutionOptions& execution_options, int64 requests_size, int64 request_index) const { @@ -849,119 +640,8 @@ StatusOr>> Service::GetArguments( return replicated_arguments; } -tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) { - VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); - - std::vector>> all_arguments; - std::vector> all_executors; - std::vector versioned_handles; - std::vector> module_configs; - std::vector computation_names; - std::vector device_handles; - - int num_requested_devices = - std::accumulate(arg->requests().begin(), arg->requests().end(), 0, - [](int a, const ExecuteRequest& r) -> int { - return a + r.execution_options().device_handles_size(); - }); - if (num_requested_devices * options_.number_of_replicas() > - execute_backend_->device_count()) { - return FailedPrecondition( - "there are not enough stream executors to execute %d computations", - num_requested_devices); - } - - for (int64 i = 0; i < arg->requests_size(); ++i) { - // Get the stream executor for the i'th computation. This stream executor - // is one of the executors to run the replicated computation. - const ExecutionOptions& execution_options = - arg->requests(i).execution_options(); - - // Get the executors. - TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options, - arg->requests_size(), i)); - - // Resolve the UserComputation object associated with the requested - // computation and compute the program shape. 
- const ExecuteRequest& request = arg->requests(i); - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(request.computation())); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - // Get the replicated arguments. - TF_ASSIGN_OR_RETURN(auto replicated_arguments, - GetArguments(execution_options, request.arguments())); - - // Create an HloModuleConfig object for the computation, given the shape of - // the program and the argument allocations. Here, we care only about the - // shapes of the arguments, so, it is sufficient to use the arguments of - // replica 0. - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, replicated_arguments.front(), - request.execution_options(), user_computation)); - VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " - << module_config->host_entry_computation_layout().ToString(); - - // Adds to the vectors to build and execute the computations after the loop. - all_arguments.push_back(replicated_arguments); - all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}}); - versioned_handles.push_back(versioned_handle); - module_configs.push_back(std::move(module_config)); - computation_names.insert(computation_names.end(), executors.size(), - user_computation->name()); - all_executors.push_back(executors); - device_handles.insert(device_handles.end(), - execution_options.device_handles().begin(), - execution_options.device_handles().end()); - } - - // Build the user computations into HloModules and compile to generate the - // executables. 
- // - // TODO(jlebar): There's currently no way to pass a device allocator to - // ExecuteParallel, so we have to pass a null device_allocator below. - TF_ASSIGN_OR_RETURN( - std::vector> executables, - BuildExecutables(versioned_handles, std::move(module_configs), - execute_backend_.get(), all_executors, - /*device_allocator=*/nullptr)); - std::vector executable_ptrs; - executable_ptrs.reserve(executables.size()); - for (const auto& executable : executables) { - executable_ptrs.push_back(executable.get()); - } - - // Execute the generated executables in parallel and return the device - // handles for each computation's output. - ExecutionProfile profile; - TF_ASSIGN_OR_RETURN( - std::vector outputs, - ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, - execute_backend_.get(), device_handles, - computation_names, &profile)); - for (const GlobalDataHandle& output : outputs) { - ExecuteResponse response; - *response.mutable_output() = output; - *response.mutable_profile() = profile; - *result->add_responses() = response; - } - - VLOG(1) << "successfully completed 'execute-parallel' request"; - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) { +Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) { VLOG(1) << "running execute-graph-parallel request"; std::vector>> all_arguments; @@ -1009,8 +689,7 @@ tensorflow::Status Service::ExecuteGraphParallel( std::unique_ptr module_config, CreateModuleConfig(request.computation().program_shape(), replicated_arguments.front(), - request.execution_options(), - /*user_computation=*/nullptr)); + request.execution_options())); VLOG(3) << "ExecuteGraphParallel created HloModuleConfig computation layout: " << module_config->host_entry_computation_layout().ToString(); @@ -1058,11 +737,11 @@ tensorflow::Status Service::ExecuteGraphParallel( } VLOG(1) 
<< "successfully completed 'execute-graph-parallel' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) { +Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) { const int64 available_device_count = execute_backend_->device_count(); const int64 replica_count = options_.number_of_replicas(); if (replica_count <= 0) { @@ -1082,20 +761,11 @@ tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, *result->add_device_handles() = device_handle; } - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::ExecuteOneToN(const ExecuteRequest* arg, - ExecuteResponse* result) { - ExecuteParallelRequest parallel_arg; - *parallel_arg.add_requests() = *arg; - ExecuteParallelResponse parallel_result; - TF_RETURN_IF_ERROR(ExecuteParallel(&parallel_arg, &parallel_result)); - return PickParallelResponse(parallel_result, result); + return Status::OK(); } -tensorflow::Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, - ExecuteResponse* result) { +Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, + ExecuteResponse* result) { ExecuteGraphParallelRequest parallel_arg; *parallel_arg.add_requests() = *arg; ExecuteParallelResponse parallel_result; @@ -1103,7 +773,7 @@ tensorflow::Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, return PickParallelResponse(parallel_result, result); } -tensorflow::Status Service::PickParallelResponse( +Status Service::PickParallelResponse( const ExecuteParallelResponse& parallel_result, ExecuteResponse* result) { // The "result device" selection is a bit hacky, but better than assuming it // is device 0. 
We have b/76035356 for restructuring the client API to clean @@ -1126,81 +796,6 @@ tensorflow::Status Service::PickParallelResponse( return Status::OK(); } -tensorflow::Status Service::Execute(const ExecuteRequest* arg, - ExecuteResponse* result) { - VLOG(1) << "running execute request: " << arg->ShortDebugString(); - - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - // If we received multiple device handles, we must partition the module. - if (arg->execution_options().device_handles_size() > 1) { - return ExecuteOneToN(arg, result); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); - TF_ASSIGN_OR_RETURN( - std::vector> replicated_arguments, - ResolveAndValidateArguments(arg->arguments(), replicas)); - - // Since we care only about the shapes of the arguments, it is sufficient to - // use the arguments of replica 0. 
- TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, replicated_arguments.front(), - arg->execution_options(), user_computation)); - - VLOG(3) << "Execute created HloModuleConfig computation layout: " - << module_config->host_entry_computation_layout().ToString(); - - TF_ASSIGN_OR_RETURN( - std::shared_ptr executable, - BuildAndCacheExecutable(versioned_handle, std::move(module_config), - execute_backend_.get(), - execute_backend_->default_stream_executor(), - result->mutable_profile())); - - if (executable->dumping()) { - executable->session_module()->set_execution_platform( - execute_backend_->platform()->Name()); - TF_RETURN_IF_ERROR(RecordArguments( - replicated_arguments.front(), - execute_backend_->default_stream_executor(), - execute_backend_->transfer_manager(), executable->session_module())); - } - - TF_ASSIGN_OR_RETURN( - *result->mutable_output(), - ExecuteAndRegisterResult( - executable.get(), replicated_arguments, execute_backend_.get(), - "result of " + user_computation->name(), result->mutable_profile())); - - if (executable->dumping()) { - TF_ASSIGN_OR_RETURN( - const ShapedBuffer* result_buffer, - allocation_tracker_.ResolveForReplica(result->output(), 0)); - TF_RETURN_IF_ERROR(RecordResult( - *result_buffer, execute_backend_->default_stream_executor(), - execute_backend_->transfer_manager(), executable->session_module())); - TF_RETURN_IF_ERROR(executable->DumpSessionModule()); - } - - VLOG(1) << "successfully completed 'execute' request"; - return tensorflow::Status::OK(); -} - StatusOr> Service::BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, @@ -1243,8 +838,8 @@ StatusOr> Service::BuildExecutable( return std::move(executable); } -tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) { +Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) { VLOG(1) << "running execute-graph 
request"; if (!arg->has_computation()) { @@ -1303,91 +898,11 @@ tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, } VLOG(1) << "successfully completed 'execute-graph' request"; - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) { - VLOG(1) << "running execute-async request: " << arg->ShortDebugString(); - - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); - TF_RET_CHECK(!replicas.empty()); - TF_ASSIGN_OR_RETURN( - std::vector> replicated_arguments, - ResolveAndValidateArguments(arg->arguments(), replicas)); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, replicated_arguments.front(), - arg->execution_options(), user_computation)); - - VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " - << module_config->host_entry_computation_layout().ToString(); - - ExecutionProfile profile; - - TF_ASSIGN_OR_RETURN( - std::shared_ptr executable, - BuildAndCacheExecutable( - versioned_handle, std::move(module_config), execute_backend_.get(), - execute_backend_->default_stream_executor(), &profile)); - - // Set up streams. 
- std::vector::SmartPtr> streams; - for (se::StreamExecutor* executor : replicas) { - TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, - execute_backend_->BorrowStream(executor)); - streams.push_back(std::move(stream)); - } - - std::vector result_buffers; - for (size_t i = 0; i < streams.size(); ++i) { - const auto& stream = streams[i]; - ExecutableRunOptions options; - options.set_stream(stream.get()); - options.set_allocator(execute_backend_->memory_allocator()); - options.set_intra_op_thread_pool( - execute_backend_->eigen_intra_op_thread_pool_device()); - - ServiceExecutableRunOptions service_options( - options, execute_backend_->StreamBorrower()); - - TF_ASSIGN_OR_RETURN(ScopedShapedBuffer this_result_buffer, - executable->ExecuteAsyncOnStream( - &service_options, replicated_arguments[i])); - - result_buffers.emplace_back(std::move(this_result_buffer)); - } - - TF_ASSIGN_OR_RETURN( - GlobalDataHandle output, - allocation_tracker_.RegisterReplicatedBuffers( - std::move(result_buffers), "result of " + user_computation->name())); - - *result->mutable_execution() = execution_tracker_.Register( - execute_backend_.get(), std::move(streams), profile, output); - streams.clear(); - - VLOG(1) << "successfully completed 'execute-async' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) { +Status Service::WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) { TF_ASSIGN_OR_RETURN(const auto execution, execution_tracker_.Resolve(arg->execution())); @@ -1398,11 +913,11 @@ tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg, TF_RETURN_IF_ERROR(execution_tracker_.Unregister(arg->execution())); VLOG(1) << "successfully completed 'wait-for-execution' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::TransferToClient(const 
TransferToClientRequest* arg, - TransferToClientResponse* result) { +Status Service::TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer, allocation_tracker_.ResolveForReplica(arg->data(), 0)); @@ -1432,7 +947,7 @@ tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, *result->mutable_literal() = result_literal->Relayout(*return_shape)->ToProto(); } - return tensorflow::Status::OK(); + return Status::OK(); } namespace { @@ -1450,8 +965,8 @@ std::unique_ptr CloneShapedBufferOnDevice( } // namespace -tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, - TransferToServerResponse* result) { +Status Service::TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) { TF_ASSIGN_OR_RETURN(std::unique_ptr literal, Literal::CreateFromProto(arg->literal())); const Shape& shape = literal->shape(); @@ -1484,11 +999,11 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, StrCat("TransferToServer literal of shape ", ShapeUtil::HumanString(shape)))); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) { +Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) { const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( @@ -1517,9 +1032,8 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, executor, *literal); } -tensorflow::Status Service::TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) { +Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) { const int64 
replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( @@ -1545,127 +1059,16 @@ tensorflow::Status Service::TransferFromOutfeed( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( executor, arg->shape_with_layout(), &literal)); *result->mutable_literal() = literal.ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) { +Status Service::ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) { return execute_backend_->ResetDevices(); } -tensorflow::Status Service::IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandleAtOperation(arg->operand()); - - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - bool is_constant, - user_computation->IsConstant(arg->operand(), arg->num_parameters())); - - result->set_is_constant(is_constant); - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandleAtOperation(arg->operand()); - - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - bool is_constant, - user_computation->IsConstant(arg->operand(), arg->parameters_size())); - if (!is_constant) { - StatusOr op_request_status = - 
user_computation->LookUpRequestForErrorReporting(arg->operand()); - string op_request_string = ""; - if (op_request_status.ok()) { - op_request_string = op_request_status.ValueOrDie()->ShortDebugString(); - } - return InvalidArgument( - "Operand to ComputeConstant depends on a parameter.\n\n" - " op requested for constant evaluation: %s\n\n" - "This is an internal error that typically happens when the XLA user " - "(e.g. TensorFlow) is attempting to determine a value that must be a " - "compile-time constant (e.g. an array dimension) but it is not capable " - "of being evaluated at XLA compile time.\n\n" - "Please file a usability bug with the framework being used (e.g. " - "TensorFlow).", - op_request_string.c_str()); - } - - // We can't use ComputeProgramShape because it checks that all parameter - // instructions are present and contiguous. Instead construct ProgramShape - // directly. - ProgramShape program_shape; - TF_ASSIGN_OR_RETURN(*program_shape.mutable_result(), - user_computation->GetShape(arg->operand())); - - TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); - - ExecutionOptions execution_options = xla::CreateDefaultExecutionOptions(); - execution_options.mutable_debug_options()->set_xla_enable_fast_math(false); - execution_options.mutable_debug_options() - ->set_xla_eliminate_hlo_implicit_broadcast(true); - *execution_options.mutable_shape_with_output_layout() = - program_shape.result(); - - Shape shape_with_output_layout(program_shape.result()); - if (arg->has_output_layout()) { - TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( - arg->output_layout(), execution_options.shape_with_output_layout())); - *execution_options.mutable_shape_with_output_layout()->mutable_layout() = - arg->output_layout(); - } - - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(program_shape, {}, execution_options, - user_computation)); - - // Exclude dead parameter instructions for the purpose of computing constants. 
- TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, *module_config, - /*include_unreachable_instructions=*/ - false)); - - std::vector> parameters(arg->parameters_size()); - for (int64 i = 0; i < arg->parameters_size(); ++i) { - TF_ASSIGN_OR_RETURN(parameters[i], - Literal::CreateFromProto(arg->parameters(i))); - } - HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN( - auto result_literal, - evaluator.Evaluate>(*module, parameters)); - - // Since the shape_with_output_layout option in ExecutionOption is - // non-effective to the Evaluator results, explicit relayout here. - // - // TODO(b/77824332): Make HloEvaluator take care of the re-layout. - if (arg->has_output_layout()) { - result_literal = result_literal->Relayout(arg->output_layout()); - } - *result->mutable_literal() = result_literal->ToProto(); - - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) { +Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) { if (!arg->has_computation()) { return InvalidArgument("computations may not be empty"); } @@ -1703,73 +1106,17 @@ tensorflow::Status Service::ComputeConstantGraph( } *result->mutable_literal() = result_literal->ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) { +Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, allocation_tracker_.ResolveForReplica(arg->data(), 0)); *result->mutable_shape() = buffer->on_host_shape(); - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - 
computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - computation->GetVersionedHandle(); - - TF_ASSIGN_OR_RETURN(auto program_shape, computation->ComputeProgramShape( - versioned_handle.version)); - *result->mutable_program_shape() = *program_shape; - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - - TF_ASSIGN_OR_RETURN(*result->mutable_shape(), - computation->GetShape(arg->operand())); - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::GetComputationStats( - const ComputationStatsRequest* arg, ComputationStatsResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - HloModuleConfig config; - config.set_debug_options(arg->debug_options()); - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, config)); - - hlo_graph_dumper::MaybeDumpHloModule(*module, - "computation statistics subject"); - - // Run HLO analysis to get the computation statistics. 
- HloCostAnalysis analysis( - execute_backend_->compiler()->ShapeSizeBytesFunction()); - - TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis)); - - ComputationStats stats; - stats.set_flop_count(analysis.flop_count()); - stats.set_transcendental_count(analysis.transcendental_count()); - *result->mutable_stats() = stats; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetComputationGraphStats( +Status Service::GetComputationGraphStats( const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) { if (!arg->has_computation()) { return InvalidArgument("Computations may not be empty."); @@ -1796,264 +1143,7 @@ tensorflow::Status Service::GetComputationGraphStats( stats.set_flop_count(analysis.flop_count()); stats.set_transcendental_count(analysis.transcendental_count()); *result->mutable_stats() = stats; - return tensorflow::Status::OK(); -} - -template -tensorflow::Status Service::AddInstruction( - const RequestT* arg, ResponseT* result, - const std::function(UserComputation*)>& - adder) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - - TF_ASSIGN_OR_RETURN(*result->mutable_output(), adder(computation)); - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - StatusOr handle_status; - - switch (arg->op_case()) { - case OpRequest::kBatchNormTrainingRequest: - handle_status = computation->AddBatchNormTrainingInstruction( - arg->batch_norm_training_request()); - break; - case OpRequest::kBatchNormInferenceRequest: - handle_status = computation->AddBatchNormInferenceInstruction( - arg->batch_norm_inference_request()); - break; - case OpRequest::kBatchNormGradRequest: - handle_status = computation->AddBatchNormGradInstruction( - arg->batch_norm_grad_request()); - break; - case 
OpRequest::kBinaryOpRequest: - handle_status = - computation->AddBinaryInstruction(arg->binary_op_request()); - break; - case OpRequest::kBroadcastRequest: - handle_status = - computation->AddBroadcastInstruction(arg->broadcast_request()); - break; - case OpRequest::kCallRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->call_request().to_apply())); - handle_status = - computation->AddCallInstruction(arg->call_request(), *to_apply); - break; - } - case OpRequest::kConcatenateRequest: - handle_status = - computation->AddConcatenateInstruction(arg->concatenate_request()); - break; - case OpRequest::kConditionalRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * true_computation, - computation_tracker_.Resolve( - arg->conditional_request().true_computation())); - TF_ASSIGN_OR_RETURN(UserComputation * false_computation, - computation_tracker_.Resolve( - arg->conditional_request().false_computation())); - handle_status = computation->AddConditionalInstruction( - arg->conditional_request(), *true_computation, *false_computation); - break; - } - case OpRequest::kConstantRequest: - handle_status = - computation->AddConstantInstruction(arg->constant_request()); - break; - case OpRequest::kConvertRequest: - handle_status = - computation->AddConvertInstruction(arg->convert_request()); - break; - case OpRequest::kBitcastConvertRequest: - handle_status = computation->AddBitcastConvertInstruction( - arg->bitcast_convert_request()); - break; - case OpRequest::kConvolveRequest: - handle_status = - computation->AddConvolveInstruction(arg->convolve_request()); - break; - case OpRequest::kCrossReplicaSumRequest: - handle_status = computation->AddCrossReplicaSumInstruction( - arg->cross_replica_sum_request()); - break; - case OpRequest::kCustomCallRequest: - handle_status = - computation->AddCustomCallInstruction(arg->custom_call_request()); - break; - case OpRequest::kDotRequest: - handle_status = 
computation->AddDotInstruction(arg->dot_request()); - break; - case OpRequest::kDynamicSliceRequest: - handle_status = - computation->AddDynamicSliceInstruction(arg->dynamic_slice_request()); - break; - case OpRequest::kDynamicUpdateSliceRequest: - handle_status = computation->AddDynamicUpdateSliceInstruction( - arg->dynamic_update_slice_request()); - break; - case OpRequest::kFftRequest: - handle_status = computation->AddFftInstruction(arg->fft_request()); - break; - case OpRequest::kGatherRequest: - handle_status = computation->AddGatherInstruction(arg->gather_request()); - break; - case OpRequest::kGetTupleElementRequest: - handle_status = computation->AddGetTupleElementInstruction( - arg->get_tuple_element_request()); - break; - case OpRequest::kInfeedRequest: - handle_status = computation->AddInfeedInstruction(arg->infeed_request()); - break; - case OpRequest::kOutfeedRequest: - handle_status = - computation->AddOutfeedInstruction(arg->outfeed_request()); - break; - case OpRequest::kHostComputeRequest: - handle_status = - computation->AddHostComputeInstruction(arg->host_compute_request()); - break; - case OpRequest::kMapRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->map_request().to_apply())); - handle_status = - computation->AddMapInstruction(arg->map_request(), *to_apply); - break; - } - case OpRequest::kPadRequest: - handle_status = computation->AddPadInstruction(arg->pad_request()); - break; - case OpRequest::kParameterRequest: - handle_status = - computation->AddParameterInstruction(arg->parameter_request()); - break; - case OpRequest::kReduceRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->reduce_request().to_apply())); - handle_status = - computation->AddReduceInstruction(arg->reduce_request(), *to_apply); - break; - } - case OpRequest::kReducePrecisionRequest: { - handle_status = computation->AddReducePrecisionInstruction( - 
arg->reduce_precision_request()); - break; - } - case OpRequest::kReduceWindowRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * to_apply, - computation_tracker_.Resolve( - arg->reduce_window_request().to_apply())); - handle_status = computation->AddReduceWindowInstruction( - arg->reduce_window_request(), *to_apply); - break; - } - case OpRequest::kReshapeRequest: - handle_status = - computation->AddReshapeInstruction(arg->reshape_request()); - break; - case OpRequest::kReverseRequest: - handle_status = - computation->AddReverseInstruction(arg->reverse_request()); - break; - case OpRequest::kRngRequest: - handle_status = computation->AddRngInstruction(arg->rng_request()); - break; - case OpRequest::kSelectAndScatterRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * select, - computation_tracker_.Resolve( - arg->select_and_scatter_request().select())); - TF_ASSIGN_OR_RETURN(UserComputation * scatter, - computation_tracker_.Resolve( - arg->select_and_scatter_request().scatter())); - handle_status = computation->AddSelectAndScatterInstruction( - arg->select_and_scatter_request(), *select, *scatter); - break; - } - case OpRequest::kSliceRequest: - handle_status = computation->AddSliceInstruction(arg->slice_request()); - break; - case OpRequest::kTernaryOpRequest: - handle_status = - computation->AddTernaryInstruction(arg->ternary_op_request()); - break; - case OpRequest::kTraceRequest: - return computation->AddTraceInstruction(arg->trace_request()); - case OpRequest::kTransposeRequest: - handle_status = - computation->AddTransposeInstruction(arg->transpose_request()); - break; - case OpRequest::kUnaryOpRequest: - handle_status = computation->AddUnaryInstruction(arg->unary_op_request()); - break; - case OpRequest::kVariadicOpRequest: - handle_status = - computation->AddVariadicInstruction(arg->variadic_op_request()); - break; - case OpRequest::kWhileRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * condition, - 
computation_tracker_.Resolve(arg->while_request().condition())); - TF_ASSIGN_OR_RETURN( - UserComputation * body, - computation_tracker_.Resolve(arg->while_request().body())); - handle_status = computation->AddWhileInstruction(arg->while_request(), - *condition, *body); - break; - } - case OpRequest::kSendRequest: { - TF_RETURN_IF_ERROR( - channel_tracker_.RegisterSend(arg->send_request().channel_handle())); - // Send does not return a value, but we need a handle to be able to - // set OpMetadata and OpSharding (device assignment). - handle_status = computation->AddSendInstruction(arg->send_request()); - break; - } - case OpRequest::kRecvRequest: { - TF_RETURN_IF_ERROR( - channel_tracker_.RegisterRecv(arg->recv_request().channel_handle())); - handle_status = computation->AddRecvInstruction(arg->recv_request()); - break; - } - case OpRequest::OP_NOT_SET: - return InvalidArgument("XLA service received OpRequest with OP_NOT_SET"); - default: - return InvalidArgument("Unsupported operation in XLA service"); - } - TF_ASSIGN_OR_RETURN(*result->mutable_output(), handle_status); - - // We set the debug metadata here, because we slice off part of the OpRequest - // proto in the above switch statement. 
- TF_ASSIGN_OR_RETURN(ComputationDataHandle handle, handle_status); - TF_RETURN_IF_ERROR(computation->SetOpMetadata(handle, arg->metadata())); - if (arg->has_sharding()) { - TF_RETURN_IF_ERROR(computation->SetOpSharding(handle, arg->sharding())); - } - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::SnapshotComputation( - const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.SnapshotComputation(arg->computation())); - - result->set_allocated_module(module.release()); - - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::LoadComputationSnapshot( - const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) { - TF_ASSIGN_OR_RETURN(*result->mutable_computation(), - computation_tracker_.LoadSessionModule(arg->module())); - return tensorflow::Status::OK(); + return Status::OK(); } DeviceHandle Service::SingleComputationDeviceHandle() const { diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index f84fe407e05da371da66ba33efd6e8165198cf2c..8748a4c1447eca691abc0f7ca48feda48ceb86e1 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -26,17 +26,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/allocation_tracker.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/channel_tracker.h" -#include "tensorflow/compiler/xla/service/compilation_cache.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/execution_tracker.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -83,57 +78,29 @@ class Service : public ServiceInterface { static StatusOr> NewService( const ServiceOptions& options); - // Creates a new computation with the given name. - // A unique ComputationHandle is returned. - tensorflow::Status Computation(const ComputationRequest* arg, - ComputationResponse* result) override; - // Unregisters a previously-allocated global handle. // // If the handle given is not currently allocated, a NOT_FOUND status is // returned. - tensorflow::Status Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) override; + Status Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) override; // Deconstructs a tuple. Returns a newly created GlobalDataHandle for each // element in the tuple. 
- tensorflow::Status DeconstructTuple( - const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result) override; - - // Modifies the provided computation so that subsequent executions - // will compute the provided ComputationDataHandle, rather than the - // last expression enqueued on that Computation. - tensorflow::Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; - - // Executes a computation with the provided global data passed as - // immutable arguments. Returns global data output and execution timing. - tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) override; + Status DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) override; // Executes a computation with the provided global data passed as // immutable arguments. The request contains the whole computation graph. // Returns global data output and execution timing. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. - tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) override; + Status ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) override; // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each // computation. - tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override; - - // Executes one or more computations in parallel with the provided global data - // passed as immutable arguments. Returns global data output for each - // computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. 
- tensorflow::Status ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, - ExecuteParallelResponse* result) override; + Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) override; // Requests one or more device handles from the target. // @@ -143,49 +110,33 @@ class Service : public ServiceInterface { // the first set of replicas, and the next R devices to the second set of // replicas, etc. Each returned device handle represents the device with the // replica id 0. - tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) override; - - // Asynchronously executes a computation with provided arguments. Invokes - // the provided computation with the provided global data passed as - // immutable arguments. Returns a handle to the execution. - // - // (Note: The corresponding function in xla::Client was removed as part of - // b/64116060, in an attempt to simplify our API. We're keeping this around - // for now in case we want to expose this to clients in a different way.) - tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override; // Waits until the specified execution is complete and returns the result. // Calling this API multiple times with the same execution handle returns the // method with an error since the execution handle is destroyed after the // first call. - tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) override; + Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override; // Requests that global data be transferred to the client in literal form. 
- tensorflow::Status TransferToClient( - const TransferToClientRequest* arg, - TransferToClientResponse* result) override; + Status TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) override; // Transfers data from a literal provided by the client, into device memory. - tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, - TransferToServerResponse* result) override; + Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) override; // Transfers data from a literal provided by the client, into the Infeed // buffer of the device. - tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) override; + Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override; // Transfers data from the Outfeed of the device to the literal provided by the // client. - tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) override; + Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override; // Resets devices, clearing all existing state on all the devices associated // with this service (including memory allocated on the devices). @@ -196,77 +147,25 @@ class Service : public ServiceInterface { // ResetDevice should be called before an Execution that expects the device to // be in the reset state. For example, if the prior Execution modifies device // state (e.g., architectural state) that the next Execution depends on. - tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) override; - - // Tests if an expression is a compile-time constant.
- tensorflow::Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) override; + Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override; - // Computes the value of a constant expression. - tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; - tensorflow::Status ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, - ComputeConstantResponse* result) override; + Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) override; // Returns the shape (with layout) of an array associated with a given data // handle. - tensorflow::Status GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) override; - - // Returns the program shape of the computation associated with the given - // handle. - tensorflow::Status GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; - - ///// - // Computation-oriented methods. - - // Enqueues an Op on the computation. - tensorflow::Status Op(const OpRequest* arg, OpResponse* result) override; - - // Retrieves the inferred shape for a value within a computation. - tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; + Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) override; // Retrieves the statistics of a computation. - tensorflow::Status GetComputationStats( - const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; - - // Retrieves the statistics of a computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. 
- tensorflow::Status GetComputationGraphStats( - const ComputationGraphStatsRequest* arg, - ComputationStatsResponse* result) override; - - // Snapshots the current state of a computation handle into a serializable - // protocol buffer form, so it can be loaded via - // LoadComputationSnapshot. - tensorflow::Status SnapshotComputation( - const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) override; - - // Loads a computation from a serialized protocol buffer created via - // SnapshotComputation. - tensorflow::Status LoadComputationSnapshot( - const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) override; + Status GetComputationGraphStats(const ComputationGraphStatsRequest* arg, + ComputationStatsResponse* result) override; // Creates a unique channel handle that can be used for Send/Recv // instructions. - tensorflow::Status CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) override; - - // Returns the ComputationTracker of the current service instance. - // Only used in unit tests to access user computations from client. - const ComputationTracker& computation_tracker() { - return computation_tracker_; - } + Status CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) override; // Returns the backend used to execute computations. const Backend& backend() const { return *execute_backend_; } @@ -278,8 +177,7 @@ class Service : public ServiceInterface { StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options, - const UserComputation* user_computation = nullptr); + const ExecutionOptions& execution_options); // Picks a parallel response and fills the result. 
Status PickParallelResponse(const ExecuteParallelResponse& parallel_result, @@ -320,23 +218,13 @@ class Service : public ServiceInterface { StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, - const ExecutionOptions* execution_options, - const UserComputation* user_computation = nullptr); + const ExecutionOptions* execution_options); // Builds an Executable for the given parameters. // // If device_allocator is not null, the compiler may use it to allocate temp // buffers, which the compiler is responsible for freeing. The allocator // given here need not match the allocator used when running the executable. - StatusOr> BuildExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, - DeviceMemoryAllocator* device_allocator = nullptr); - - // Builds an Executable for the given HLO module proto. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, @@ -345,26 +233,12 @@ class Service : public ServiceInterface { // Same as BuildExecutable() above, but builds a list of Executables for the // given computations that may interact with each other. - StatusOr>> BuildExecutables( - std::vector versioned_handles, - std::vector> module_configs, - Backend* backend, std::vector> executors, - DeviceMemoryAllocator* device_allocator); StatusOr>> BuildExecutables( const std::vector& module_protos, std::vector> module_configs, Backend* backend, std::vector> executors, DeviceMemoryAllocator* device_allocator); - // Similar to BuildExecutable, but look in the compilation cache for the - // executable first. If the executable is not in the cache, it is built and - // inserted into the cache. 
- StatusOr> BuildAndCacheExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, ExecutionProfile* profile, - DeviceMemoryAllocator* device_allocator = nullptr); - // Runs the given executable with the given arguments and register the result // in the allocation tracker. The handle of the result from the tracker is // returned. If the parameter "profile" is not null, it points to an @@ -387,26 +261,16 @@ class Service : public ServiceInterface { tensorflow::gtl::ArraySlice result_tags, ExecutionProfile* profile); - // Convenience function for adding a function to a user computation. - template - tensorflow::Status AddInstruction( - const RequestT* arg, ResponseT* result, - const std::function(UserComputation*)>& - adder); - // Executes a single computation which has more than one target device. // The N devices are expected to all return an empty tuple, but one, which // will be the result of this computation. - tensorflow::Status ExecuteOneToN(const ExecuteRequest* arg, - ExecuteResponse* result); - tensorflow::Status ExecuteOneToN(const ExecuteGraphRequest* arg, - ExecuteResponse* result); + Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result); // Convenience function which checks whether the given shape_with_layout // (presumably passed by the client to set the result layout) is valid for the // given computation result shape. - tensorflow::Status ValidateResultShapeWithLayout( - const Shape& shape_with_layout, const Shape& result_shape) const; + Status ValidateResultShapeWithLayout(const Shape& shape_with_layout, + const Shape& result_shape) const; // Returns the stream executors assigned to the replicas represented by the // given device handle. Each device_handle is a virtual replicated device that @@ -422,9 +286,6 @@ class Service : public ServiceInterface { ServiceOptions options_; - // Tracks computations built via the API. 
- ComputationTracker computation_tracker_; - // Tracks channels created via the API. ChannelTracker channel_tracker_; @@ -434,9 +295,6 @@ class Service : public ServiceInterface { // Tracks asynchronously launched executions via the API. ExecutionTracker execution_tracker_; - // Cache containing previously built Executables. - CompilationCache compilation_cache_; - // Backend to compile and execute computations on. std::unique_ptr execute_backend_; diff --git a/tensorflow/compiler/xla/service/session.proto b/tensorflow/compiler/xla/service/session.proto deleted file mode 100644 index bb8d1cd2a106ea3e5bb61eee5052bd60c38cd0e2..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/session.proto +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This proto file defines messages which store the state of XLA -// computations within the XLA service. A computation is stored as a record -// of the operation requests used to build it. -syntax = "proto3"; - -import "tensorflow/compiler/xla/xla_data.proto"; - -package xla; - -// Describes a single operation request. 
-message OperationRequest { - ComputationDataHandle output_handle = 1; - Shape output_shape = 2; - - // For operations which call embedded computations such as "Map", these are - // the version(s) that the embedded computation should be called at. A version - // value of a computation is the ComputationDataHandle of the root of the - // computation at the point in time. - // - // "Call", "Map", "Reduce", and "ReduceWindow" operations take a single - // embedded computation so this field will have a single value for those - // operations. - // - // "While" operation takes two; index 0 is the "condition" version and index 1 - // is the "body" version. - repeated int64 embedded_computation_versions = 3; - - // The actual request, which in itself is a tagged union of all possible - // operation request types. - OpRequest request = 4; -} - -// Describes a sequence of operation requests which define an XLA -// computation. -message SessionComputation { - string name = 1; - - // The ComputationHandle used to refer to this computation in the XLA - // service. - ComputationHandle computation_handle = 2; - - // Map from ComputationDataHandle value to operation request. The highest - // ComputationDataHandle value corresponds to the root of the computation. - map requests = 3; -} - -// Describes a group of SessionComputations with an "entry point" computation -// that may refer to the other non-entry (AKA embedded) computations. -// -// This message is used to serialize a computation that has been built via the -// XLA service API, along with its dependencies, for purposes such as -// analysis/replay/file-storage. -message SessionModule { - // The entry computation, which was requested for serialization. This may have - // referred to embedded computations, which are reflected below. - SessionComputation entry = 1; - - // Embedded computations that are transitively referred to by the entry - // computation. 
- repeated SessionComputation embedded_computations = 2; - - // The arguments passed to the computation. - repeated LiteralProto arguments = 3; - - // The result of the computation. - LiteralProto result = 4; - - // The name of the platform used to run the computation. - string execution_platform = 5; -} diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index c493547d9e83e19c09329f32873de3a9a330b460..bd98e86b08b7507b4a7a0d1a7ebac4b654ff2171 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -44,132 +44,13 @@ namespace xla { namespace { -// Return the UnaryOperation proto enum value associated with the given HLO -// opcode. -UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kAbs: - return UNOP_ABS; - case HloOpcode::kCeil: - return UNOP_CEIL; - case HloOpcode::kClz: - return UNOP_CLZ; - case HloOpcode::kCos: - return UNOP_COS; - case HloOpcode::kExp: - return UNOP_EXP; - case HloOpcode::kFloor: - return UNOP_FLOOR; - case HloOpcode::kImag: - return UNOP_IMAG; - case HloOpcode::kIsFinite: - return UNOP_IS_FINITE; - case HloOpcode::kLog: - return UNOP_LOG; - case HloOpcode::kNot: - return UNOP_NOT; - case HloOpcode::kNegate: - return UNOP_NEGATE; - case HloOpcode::kReal: - return UNOP_REAL; - case HloOpcode::kRoundNearestAfz: - return UNOP_ROUND_NEAREST_AFZ; - case HloOpcode::kSign: - return UNOP_SIGN; - case HloOpcode::kSin: - return UNOP_SIN; - case HloOpcode::kSort: - return UNOP_SORT; - case HloOpcode::kTanh: - return UNOP_TANH; - default: - LOG(FATAL) << "Unhandled opcode for conversion to unary operation: " - << opcode; - } -} - -// Return the BinaryOperation proto enum value associated with the given HLO -// opcode. 
-BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kAtan2: - return BINOP_ATAN2; - case HloOpcode::kComplex: - return BINOP_COMPLEX; - case HloOpcode::kMultiply: - return BINOP_MUL; - case HloOpcode::kAdd: - return BINOP_ADD; - case HloOpcode::kSubtract: - return BINOP_SUB; - case HloOpcode::kDivide: - return BINOP_DIV; - case HloOpcode::kEq: - return BINOP_EQ; - case HloOpcode::kGe: - return BINOP_GE; - case HloOpcode::kGt: - return BINOP_GT; - case HloOpcode::kLe: - return BINOP_LE; - case HloOpcode::kLt: - return BINOP_LT; - case HloOpcode::kNe: - return BINOP_NE; - case HloOpcode::kMaximum: - return BINOP_MAX; - case HloOpcode::kMinimum: - return BINOP_MIN; - case HloOpcode::kPower: - return BINOP_POW; - case HloOpcode::kRemainder: - return BINOP_REM; - case HloOpcode::kOr: - return BINOP_OR; - case HloOpcode::kAnd: - return BINOP_AND; - case HloOpcode::kShiftLeft: - return BINOP_SHIFT_LEFT; - case HloOpcode::kShiftRightArithmetic: - return BINOP_SHIFT_RIGHT_ARITHMETIC; - case HloOpcode::kShiftRightLogical: - return BINOP_SHIFT_RIGHT_LOGICAL; - default: - LOG(FATAL) << "unhandled opcode " << opcode; - } -} - -// Return the TernaryOperation proto enum value associated with the given HLO -// opcode. -TernaryOperation OpcodeToTernaryOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kClamp: - return TRIOP_CLAMP; - case HloOpcode::kSelect: - return TRIOP_SELECT; - default: - LOG(FATAL) << "unhandled opcode " << opcode; - } -} - -// Return the VariadicOperation proto enum value associated with the given HLO -// opcode. -VariadicOperation OpcodeToVariadicOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kTuple: - return VAROP_TUPLE; - default: - LOG(FATAL) << "unhandled opcode " << opcode; - } -} - // Returns true if no element is present in slice more than once. 
bool AllUnique(tensorflow::gtl::ArraySlice slice) { return std::set(slice.begin(), slice.end()).size() == slice.size(); } -tensorflow::Status ExpectNotTupleOrOpaque(const Shape& shape, - tensorflow::StringPiece op_type) { +Status ExpectNotTupleOrOpaque(const Shape& shape, + tensorflow::StringPiece op_type) { if (ShapeUtil::IsTuple(shape)) { return InvalidArgument("Expected non-tuple argument for %s, but got %s.", std::string(op_type).c_str(), @@ -179,13 +60,13 @@ tensorflow::Status ExpectNotTupleOrOpaque(const Shape& shape, std::string(op_type).c_str(), ShapeUtil::HumanString(shape).c_str()); } else { - return tensorflow::Status::OK(); + return Status::OK(); } } -tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, - const Shape& init_value_shape, - const PrimitiveType& input_element_type) { +Status VerifyReducerShape(const ProgramShape& reducer_shape, + const Shape& init_value_shape, + const PrimitiveType& input_element_type) { if (reducer_shape.parameters_size() != 2) { return InvalidArgument( "Reduction function must take 2 parameters, but " @@ -245,7 +126,7 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, ShapeUtil::HumanString(accumulator_shape).c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr InferWindowOutputShape(const Shape& base_shape, @@ -312,86 +193,86 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, /* static */ StatusOr ShapeInference::InferUnaryOpShape( HloOpcode opcode, const Shape& shape) { // There is no copy operation at the proto level, so handle copy explicitly. - if (opcode == HloOpcode::kCopy) { + // A domain shape is the same as the input one. 
+ if (opcode == HloOpcode::kCopy || opcode == HloOpcode::kDomain) { return shape; } - return InferUnaryOpShape(OpcodeToUnaryOperation(opcode), shape); -} + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(shape, "operand of unary operation")); -/* static */ StatusOr ShapeInference::InferUnaryOpShape( - UnaryOperation operation, const Shape& arg) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of unary operation")); - - TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(arg)); - switch (operation) { - case UNOP_FLOOR: - case UNOP_CEIL: - if (!ShapeUtil::ElementIsFloating(arg)) { + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); + switch (opcode) { + case HloOpcode::kFloor: + case HloOpcode::kCeil: + if (!ShapeUtil::ElementIsFloating(shape)) { return InvalidArgument( "Expected element type in shape to be floating for floor/ceil " "operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return arg; - case UNOP_COS: - case UNOP_SIN: - case UNOP_EXP: - case UNOP_LOG: - case UNOP_TANH: - if (!ShapeUtil::ElementIsFloating(arg) && - !ShapeUtil::ElementIsComplex(arg)) { + return shape; + case HloOpcode::kCos: + case HloOpcode::kSin: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kTanh: + if (!ShapeUtil::ElementIsFloating(shape) && + !ShapeUtil::ElementIsComplex(shape)) { return InvalidArgument( "Expected element type in shape to be floating or complex for " "sin/cos/exp/log/tanh operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return arg; - case UNOP_REAL: - case UNOP_IMAG: - if (!ShapeUtil::ElementIsComplex(arg)) { + return shape; + case HloOpcode::kReal: + case HloOpcode::kImag: + if (!ShapeUtil::ElementIsComplex(shape)) { return InvalidArgument( "Expected element type in shape to be complex for real/imag " "operation; got 
%s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return ShapeUtil::ChangeElementType(arg, F32); - case UNOP_ABS: - if (ShapeUtil::ElementIsComplex(arg)) { + return ShapeUtil::ChangeElementType(shape, F32); + case HloOpcode::kAbs: + if (ShapeUtil::ElementIsComplex(shape)) { return ShapeUtil::ChangeElementType( - arg, primitive_util::ComplexComponentType(arg.element_type())); + shape, primitive_util::ComplexComponentType(shape.element_type())); } - return arg; - case UNOP_CLZ: - case UNOP_NEGATE: - case UNOP_ROUND_NEAREST_AFZ: - case UNOP_SIGN: - case UNOP_SORT: - return arg; - - case UNOP_NOT: - if (arg.element_type() != PRED && - !primitive_util::IsIntegralType(arg.element_type())) { + return shape; + case HloOpcode::kClz: + case HloOpcode::kNegate: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kSign: + case HloOpcode::kSort: + return shape; + + case HloOpcode::kNot: + if (shape.element_type() != PRED && + !primitive_util::IsIntegralType(shape.element_type())) { return InvalidArgument( "Expected pred or an integral element type in argument to Not " "operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return arg; + return shape; - case UNOP_IS_FINITE: - if (!ShapeUtil::ElementIsFloating(arg)) { + case HloOpcode::kIsFinite: + if (!ShapeUtil::ElementIsFloating(shape)) { return InvalidArgument( - "Expected element type in shape to be floating point for IsFinite " + "Expected element type in shape to be floating " + "point for IsFinite " "operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return ShapeUtil::ChangeElementType(arg, PRED); + return ShapeUtil::ChangeElementType(shape, PRED); default: return InvalidArgument( "Unknown operation for unary shape inference: \"%s\".", - UnaryOperation_Name(operation).c_str()); + 
HloOpcodeString(opcode).c_str()); } } @@ -456,6 +337,17 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::MakeShape(element_type, new_dimensions); } +/* static */ StatusOr ShapeInference::InferTokenShape( + tensorflow::gtl::ArraySlice arg_shapes) { + for (const Shape* arg_shape : arg_shapes) { + if (arg_shape->element_type() != TOKEN) { + return InvalidArgument( + "Operands of token instructions must be TOKEN types."); + } + } + return ShapeUtil::MakeTokenShape(); +} + /* static */ StatusOr ShapeInference::InferConvertShape( const Shape& operand_shape, PrimitiveType new_element_type) { auto old_element_type = operand_shape.element_type(); @@ -761,8 +653,9 @@ Status ValidateDotDimensionNumbers( } /* static */ StatusOr -ShapeInference::InferDegenerateDimensionBroadcastShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs) { +ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, + const Shape& lhs, + const Shape& rhs) { TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)); // The shapes have to be compatible. 
That is, if some dimension d has a @@ -780,7 +673,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } else { return InvalidArgument( "Binary op %s with incompatible shapes: %s and %s.", - BinaryOperation_Name(operation).c_str(), + HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); } @@ -790,8 +683,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } /* static */ StatusOr ShapeInference::InferInDimBroadcastShape( - BinaryOperation operation, const Shape& smaller_shape, - const Shape& larger_shape, + const Shape& smaller_shape, const Shape& larger_shape, tensorflow::gtl::ArraySlice broadcast_dimensions) { if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) { // Reject "magic" inference for binops on different shapes, requiring @@ -892,7 +784,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } /* static */ StatusOr ShapeInference::InferElementwiseBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, + HloOpcode operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { TF_RETURN_IF_ERROR( ExpectNotTupleOrOpaque(lhs, "lhs of elementwise binary operation")); @@ -902,8 +794,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Binary op %s with different element types: %s and %s.", - BinaryOperation_Name(operation).c_str(), - ShapeUtil::HumanString(lhs).c_str(), + HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); } @@ -936,10 +827,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? rhs : lhs; // After InDim broadcasting, perform degenerate dimensions broadcasting. 
- TF_ASSIGN_OR_RETURN( - Shape indim_broadcast_shape, - InferInDimBroadcastShape(operation, smaller_shape, larger_shape, - broadcast_dimensions)); + TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape, + InferInDimBroadcastShape(smaller_shape, larger_shape, + broadcast_dimensions)); return InferDegenerateDimensionBroadcastShape( operation, indim_broadcast_shape, larger_shape); @@ -948,51 +838,44 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferBinaryOpShape( HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) { - return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs->shape(), - rhs->shape(), /*broadcast_dimensions=*/{}); + return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(), + /*broadcast_dimensions=*/{}); } /* static */ StatusOr ShapeInference::InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs, rhs, - broadcast_dimensions); -} - -/* static */ StatusOr ShapeInference::InferBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { VLOG(2) << tensorflow::strings::Printf( "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}", - BinaryOperation_Name(operation).c_str(), - ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(), + HloOpcodeString(opcode).c_str(), ShapeUtil::HumanString(lhs).c_str(), + ShapeUtil::HumanString(rhs).c_str(), Join(broadcast_dimensions, ", ").c_str()); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( lhs, tensorflow::strings::StrCat("lhs of binary operation ", - BinaryOperation_Name(operation)))); + HloOpcodeString(opcode)))); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( rhs, 
tensorflow::strings::StrCat("rhs of binary operation ", - BinaryOperation_Name(operation)))); - switch (operation) { - case BINOP_MAX: - case BINOP_MIN: - case BINOP_SUB: - case BINOP_ADD: - case BINOP_ATAN2: - case BINOP_POW: - case BINOP_DIV: - case BINOP_REM: - case BINOP_MUL: - case BINOP_SHIFT_LEFT: - case BINOP_SHIFT_RIGHT_ARITHMETIC: - case BINOP_SHIFT_RIGHT_LOGICAL: - return InferElementwiseBinaryOpShape(operation, lhs, rhs, + HloOpcodeString(opcode)))); + switch (opcode) { + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kSubtract: + case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kPower: + case HloOpcode::kDivide: + case HloOpcode::kRemainder: + case HloOpcode::kMultiply: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); - case BINOP_COMPLEX: { + case HloOpcode::kComplex: { if (!ShapeUtil::ElementIsFloating(lhs)) { return InvalidArgument( "Expected element type in shape to be floating for complex compose " @@ -1000,7 +883,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(lhs.element_type()).c_str()); } TF_ASSIGN_OR_RETURN(const Shape& shape, - InferElementwiseBinaryOpShape(operation, lhs, rhs, + InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions)); if (lhs.element_type() == F32 && rhs.element_type() == F32) { return ShapeUtil::ChangeElementType(shape, C64); @@ -1008,8 +891,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return Unimplemented("Complex component type is not implemented."); } } - case BINOP_AND: - case BINOP_OR: + case HloOpcode::kAnd: + case HloOpcode::kOr: if (lhs.element_type() != PRED && !primitive_util::IsIntegralType(lhs.element_type())) { return InvalidArgument( @@ -1017,24 +900,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "got %s.", 
PrimitiveType_Name(lhs.element_type()).c_str()); } - return InferElementwiseBinaryOpShape(operation, lhs, rhs, + return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); - case BINOP_EQ: - case BINOP_GE: - case BINOP_GT: - case BINOP_LE: - case BINOP_LT: - case BINOP_NE: { + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kNe: { TF_ASSIGN_OR_RETURN(const Shape& shape, - InferElementwiseBinaryOpShape(operation, lhs, rhs, + InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions)); return ShapeUtil::ChangeElementType(shape, PRED); } default: return Unimplemented( "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.", - BinaryOperation_Name(operation).c_str(), - lhs.ShortDebugString().c_str(), rhs.ShortDebugString().c_str()); + HloOpcodeString(opcode).c_str(), lhs.ShortDebugString().c_str(), + rhs.ShortDebugString().c_str()); } } @@ -1046,23 +929,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferTernaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) { - return InferTernaryOpShape(OpcodeToTernaryOperation(opcode), lhs, rhs, ehs); -} - -/* static */ StatusOr ShapeInference::InferTernaryOpShape( - TernaryOperation operation, const Shape& lhs, const Shape& rhs, - const Shape& ehs) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs)); - switch (operation) { - case TRIOP_CLAMP: + switch (opcode) { + case HloOpcode::kClamp: return InferClampShape(lhs, rhs, ehs); - case TRIOP_SELECT: + case HloOpcode::kSelect: return InferSelectShape(lhs, rhs, ehs); default: return InvalidArgument("Unknown operation %s.", - TernaryOperation_Name(operation).c_str()); + HloOpcodeString(opcode).c_str()); } } @@ -1079,18 +956,11 @@ 
ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferVariadicOpShape( HloOpcode opcode, tensorflow::gtl::ArraySlice operand_shapes) { - return InferVariadicOpShape(OpcodeToVariadicOperation(opcode), - operand_shapes); -} - -/* static */ StatusOr ShapeInference::InferVariadicOpShape( - VariadicOperation operation, - tensorflow::gtl::ArraySlice operand_shapes) { for (const Shape* shape : operand_shapes) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape)); } - switch (operation) { - case VAROP_TUPLE: { + switch (opcode) { + case HloOpcode::kTuple: { Shape result = ShapeUtil::MakeTupleShape({}); for (const Shape* shape : operand_shapes) { ShapeUtil::AppendShapeToTuple(*shape, &result); @@ -1099,7 +969,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } default: return InvalidArgument("Unknown operation %s.", - VariadicOperation_Name(operation).c_str()); + HloOpcodeString(opcode).c_str()); } } @@ -1212,11 +1082,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( scale_shape, "scale input of batch norm training")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) == - tensorflow::Status::OK()); + Status::OK()); if (feature_index >= ShapeUtil::Rank(operand_shape)) { return InvalidArgument( @@ -1318,15 +1188,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( scale_shape, "scale input of batch norm inference")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) == - 
tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape) == - tensorflow::Status::OK()); + Status::OK()); if (feature_index >= ShapeUtil::Rank(operand_shape)) { return InvalidArgument( diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 9da2c99b4177f08ece8daabaf2922ddd7e947a1b..f1f7b50902d899c0c629c3098d80fc400fb1388d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -46,8 +46,6 @@ class ShapeInference { public: // Infers the shape produced by applying the given unary operation to the // given input shape. - static StatusOr InferUnaryOpShape(UnaryOperation operation, - const Shape& arg); static StatusOr InferUnaryOpShape(HloOpcode opcode, const Shape& shape); static StatusOr InferUnaryOpShape(HloOpcode opcode, @@ -55,9 +53,6 @@ class ShapeInference { // Infers the shape produced by applying the given binary operation to the // given input shapes. - static StatusOr InferBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); static StatusOr InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions); @@ -67,9 +62,6 @@ class ShapeInference { // Infers the shape produced by applying the given ternary operation to the // given input shapes. - static StatusOr InferTernaryOpShape(TernaryOperation operation, - const Shape& lhs, const Shape& rhs, - const Shape& ehs); static StatusOr InferTernaryOpShape(HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs); @@ -80,9 +72,6 @@ class ShapeInference { // Infers the shape produced by applying the given variadic operation to the // given input operand shapes. 
- static StatusOr InferVariadicOpShape( - VariadicOperation operation, - tensorflow::gtl::ArraySlice operand_shapes); static StatusOr InferVariadicOpShape( HloOpcode opcode, tensorflow::gtl::ArraySlice operand_shapes); @@ -227,6 +216,13 @@ class ShapeInference { static StatusOr InferConcatOpShape( tensorflow::gtl::ArraySlice arg_shapes, int64 dimension); + // Infers the shape produced by a kGenerateToken operation. Trivially this + // shape is always a TOKEN shape. However, ShapeInference serves two purposes: + // inferring shapes and checking operand shapes. This method verifies that the + // operand shapes are all TOKENs. + static StatusOr InferTokenShape( + tensorflow::gtl::ArraySlice arg_shapes); + // Helper that validates the given operand shape can be converted to the // target output_shape via a convert instruction -- the requirement is that // the shape is identical except for the element type. @@ -279,7 +275,7 @@ class ShapeInference { // the LHS and a single element in the RHS to produce a single output element, // even in the presence of broadcasting of one of the operands over the other. static StatusOr InferElementwiseBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, + HloOpcode operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions); // Helper for inferring the shape of Clamp ops. @@ -295,7 +291,7 @@ class ShapeInference { // dimension broadcasting (a dimension of size 1 in one operand is broadcast // up to match the size of the dimension in the other operand). static StatusOr InferDegenerateDimensionBroadcastShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs); + HloOpcode operation, const Shape& lhs, const Shape& rhs); // Helper for inferring shapes of binary operations using "InDim" // broadcasting. This is the broadcasting used in the *InDim binary operations @@ -303,8 +299,7 @@ class ShapeInference { // lower-rank shape than larger_shape. 
Returns the shape that the // smaller_shape is broadcast to. static StatusOr InferInDimBroadcastShape( - BinaryOperation operation, const Shape& smaller_shape, - const Shape& larger_shape, + const Shape& smaller_shape, const Shape& larger_shape, tensorflow::gtl::ArraySlice broadcast_dimensions); TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference); diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 0e61994a786b53a295ef9c9c2287b28fbf754d9b..6d017dffe2d8f927abad4a62bff7fe41bc871975 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -101,8 +101,8 @@ class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest { TEST_F(ShapeInferenceTest, UnaryNegateMatrix) { Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); - auto inferred_status = ShapeInference::InferUnaryOpShape( - UnaryOperation::UNOP_NEGATE, matrix_shape); + auto inferred_status = + ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, matrix_shape); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, inferred_status.ValueOrDie())); } @@ -110,14 +110,14 @@ TEST_F(ShapeInferenceTest, UnaryNegateMatrix) { TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) { Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_}); auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, pred_, tuple, tuple); + HloOpcode::kSelect, pred_, tuple, tuple); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(tuple, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, 
inferred_status.ValueOrDie())); } @@ -125,34 +125,34 @@ TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) { TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) { auto predarray = ShapeUtil::MakeShape(PRED, {64, 48}); auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, predarray, matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, predarray, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, SelectBadShapes) { auto inferred_status_error1 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_); + HloOpcode::kSelect, pred_, matrix_64_48_, matrix_32_64_); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), HasSubstr("Operands to select must be the same shape")); auto inferred_status_error2 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, s32_, matrix_64_48_, matrix_64_48_); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), HasSubstr("pred operand must have PRED")); auto inferred_status_error3 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeShape(PRED, {64}), - matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_, + matrix_64_48_); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), HasSubstr("with non-scalar predicate with dimensionality")); // Tuples have a TUPLE element type and cannot be the pred of a select. 
auto inferred_status_error4 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeTupleShape({pred_, pred_}), + HloOpcode::kSelect, ShapeUtil::MakeTupleShape({pred_, pred_}), ShapeUtil::MakeTupleShape({f32_, f32_}), ShapeUtil::MakeTupleShape({f32_, f32_})); ASSERT_FALSE(inferred_status_error4.ok()); @@ -162,102 +162,98 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { TEST_F(ShapeInferenceTest, ClampAllMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, - matrix_64_48_); + HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampAllScalar) { - auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, f32_, f32_); + auto inferred_status = + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMinScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, matrix_64_48_); + HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMaxScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, f32_); + HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampOperandScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, 
matrix_64_48_, f32_, matrix_64_48_); + HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMinMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, f32_); + HloOpcode::kClamp, matrix_64_48_, f32_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMaxMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, f32_, matrix_64_48_); + HloOpcode::kClamp, f32_, f32_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampOperandMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, f32_); + HloOpcode::kClamp, f32_, matrix_64_48_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampBadShapes) { // Type mismatch - ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, s32_, f32_, f32_) - .ok()); - ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, s32_, f32_) - .ok()); - ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, f32_, s32_) - .ok()); - // Dimension mismatch ASSERT_FALSE( - ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, - vector_64_, vector_32_, vector_32_) + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, s32_, f32_, f32_) .ok()); ASSERT_FALSE( - ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, - vector_32_, vector_64_, vector_32_) + 
ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, s32_, f32_) .ok()); ASSERT_FALSE( - ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, - vector_32_, vector_32_, vector_64_) + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, s32_) .ok()); - // Dimension mismatch, where one operand is a scalar + // Dimension mismatch ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, vector_64_, vector_32_, f32_) + HloOpcode::kClamp, vector_64_, vector_32_, vector_32_) .ok()); ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, vector_64_, f32_, vector_32_) + HloOpcode::kClamp, vector_32_, vector_64_, vector_32_) .ok()); ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, vector_64_, vector_32_) + HloOpcode::kClamp, vector_32_, vector_32_, vector_64_) + .ok()); + // Dimension mismatch, where one operand is a scalar + ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, + vector_64_, vector_32_, f32_) + .ok()); + ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, + vector_64_, f32_, vector_32_) + .ok()); + ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, + vector_64_, vector_32_) .ok()); } TEST_F(ShapeInferenceTest, Complex) { auto complex_shape = [&](const Shape& lhs, const Shape& rhs, const tensorflow::gtl::ArraySlice& bcast) { - return ShapeInference::InferBinaryOpShape(BinaryOperation::BINOP_COMPLEX, - lhs, rhs, bcast); + return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs, + bcast); }; // Inputs must be FP. 
ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok()); @@ -292,8 +288,8 @@ TEST_F(ShapeInferenceTest, Complex) { } TEST_F(ShapeInferenceTest, VariadicOpTuplify) { - StatusOr result = ShapeInference::InferVariadicOpShape( - VariadicOperation::VAROP_TUPLE, {&s32_, &f32_}); + StatusOr result = + ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_}); ASSERT_IS_OK(result.status()); ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), ShapeUtil::MakeTupleShape({s32_, f32_}))); @@ -804,8 +800,8 @@ TEST_F(ShapeInferenceTest, InferConstIndexShape) { TEST_F(ShapeInferenceTest, InferPowShape) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_POW, ten_floats, f32_, {}); + auto inferred_status = ShapeInference::InferBinaryOpShape( + HloOpcode::kPower, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie())); } @@ -813,7 +809,7 @@ TEST_F(ShapeInferenceTest, InferPowShape) { TEST_F(ShapeInferenceTest, InferCompareShapeEq) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_EQ, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kEq, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -822,7 +818,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeEq) { TEST_F(ShapeInferenceTest, InferCompareShapeGe) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_GE, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kGe, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -831,7 +827,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeGe) { 
TEST_F(ShapeInferenceTest, InferCompareShapeGt) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_GT, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kGt, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -840,7 +836,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeGt) { TEST_F(ShapeInferenceTest, InferCompareShapeLe) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_LE, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kLe, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -849,7 +845,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeLe) { TEST_F(ShapeInferenceTest, InferCompareShapeLt) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_LT, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kLt, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -858,7 +854,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeLt) { TEST_F(ShapeInferenceTest, InferCompareShapeNe) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_NE, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kNe, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -1111,22 +1107,22 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) { const Shape vec8 = ShapeUtil::MakeShape(F32, {8}); const Shape vec16 = 
ShapeUtil::MakeShape(F32, {16}); - auto inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec8, {1}); + auto inferred_status_match = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {1}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat)); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec8, {0}); + auto inferred_status_mismatch = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {0}); ASSERT_FALSE(inferred_status_mismatch.ok()); - inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec16, {0}); + inferred_status_match = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {0}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat)); - inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec16, {1}); + inferred_status_mismatch = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {1}); ASSERT_FALSE(inferred_status_mismatch.ok()); } @@ -1138,17 +1134,17 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) { const Shape matrix16_8 = ShapeUtil::MakeShape(F32, {16, 8}); auto inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, cube, matrix8_4, {1, 2}); + HloOpcode::kAdd, cube, matrix8_4, {1, 2}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, cube, matrix16_4, {0, 2}); + HloOpcode::kAdd, cube, matrix16_4, {0, 2}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); inferred_status_match = ShapeInference::InferBinaryOpShape( - 
BinaryOperation::BINOP_ADD, cube, matrix16_8, {0, 1}); + HloOpcode::kAdd, cube, matrix16_8, {0, 1}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); } @@ -1162,43 +1158,43 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8}); // "magical" broadcast rejected - auto inferred_status_error1 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, vec8, {}); + auto inferred_status_error1 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {}); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), HasSubstr("Automatic")); // broadcast_dimension out of bounds for tensor's rank - auto inferred_status_error2 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, vec8, {3}); + auto inferred_status_error2 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {3}); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), ContainsRegex("Broadcast dimension number .* too large")); // broadcast_dimension doesn't match corresponding dimension - auto inferred_status_error3 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, vec8, {0}); + auto inferred_status_error3 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {0}); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), HasSubstr("Broadcast dimension 0 mismatch")); // broadcast_dimensions list too long auto inferred_status_error4 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2}); + HloOpcode::kAdd, tensor, matrix8_4, {0, 1, 2}); ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT(inferred_status_error4.status().error_message(), HasSubstr("broadcast_dimensions has to match")); // 
there's a dimension above the rank of the tensor auto inferred_status_error5 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, matrix8_4, {3, 0}); + HloOpcode::kAdd, tensor, matrix8_4, {3, 0}); ASSERT_FALSE(inferred_status_error5.ok()); ASSERT_THAT(inferred_status_error5.status().error_message(), ContainsRegex("dimension number .* too large")); // broadcasting dimensions don't match in this order auto inferred_status_error6 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1}); + HloOpcode::kAdd, tensor, matrix8_4, {2, 1}); ASSERT_FALSE(inferred_status_error6.ok()); ASSERT_THAT(inferred_status_error6.status().error_message(), HasSubstr("dimension 0 mismatch")); @@ -1207,13 +1203,13 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { // in a proper (strictly increasing) order, even if the lower-rank array // matches the higher-rank array in many different ways. auto inferred_status_error7 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0}); + HloOpcode::kAdd, tensor8_8_8, matrix8_8, {0, 0}); ASSERT_FALSE(inferred_status_error7.ok()); ASSERT_THAT(inferred_status_error7.status().error_message(), HasSubstr("dimensions order is wrong")); auto inferred_status_error8 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0}); + HloOpcode::kAdd, tensor8_8_8, matrix8_8, {1, 0}); ASSERT_FALSE(inferred_status_error8.ok()); ASSERT_THAT(inferred_status_error8.status().error_message(), HasSubstr("dimensions order is wrong")); diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index fb3b5f06dad67b4305aed0305c9f6441e666db53..7d7dcac10b65933d1c81b8aca77465932694bfdb 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -15,7 +15,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/shaped_buffer.h" -#include #include #include @@ -25,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" @@ -123,6 +123,8 @@ ScopedShapedBuffer::ScopedShapedBuffer(ScopedShapedBuffer&& s) } ScopedShapedBuffer& ScopedShapedBuffer::operator=(ScopedShapedBuffer&& s) { + Deallocate(); + *static_cast(this) = std::move(static_cast(s)); allocator_ = s.allocator_; // Null out s.allocator_ so it doesn't try to free anything in its destructor. @@ -130,7 +132,15 @@ ScopedShapedBuffer& ScopedShapedBuffer::operator=(ScopedShapedBuffer&& s) { return *this; } -ScopedShapedBuffer::~ScopedShapedBuffer() { +ScopedShapedBuffer::~ScopedShapedBuffer() { Deallocate(); } + +ShapedBuffer ScopedShapedBuffer::release() { + ShapedBuffer shaped_buffer(static_cast(*this)); + buffers_ = ShapeTree(); + return shaped_buffer; +} + +void ScopedShapedBuffer::Deallocate() { // allocator_ will be null if we were moved-from. if (allocator_ == nullptr) { return; @@ -138,22 +148,14 @@ ScopedShapedBuffer::~ScopedShapedBuffer() { // Deallocate all non-null buffers. A buffer may appear in more than one spot // in the shape (eg, a tuple with a repeated element) so keep track of what // has been deallocated. 
- std::set deallocated_opaques; + tensorflow::gtl::FlatSet deallocated_ptrs; for (auto& pair : buffers_) { se::DeviceMemoryBase& memory_base = pair.second; if (!memory_base.is_null() && - deallocated_opaques.count(memory_base.opaque()) == 0) { - deallocated_opaques.insert(memory_base.opaque()); - TF_CHECK_OK( - this->allocator_->Deallocate(this->device_ordinal(), &memory_base)); + deallocated_ptrs.insert(memory_base.opaque()).second) { + TF_CHECK_OK(allocator_->Deallocate(device_ordinal(), memory_base)); } } } -ShapedBuffer ScopedShapedBuffer::release() { - ShapedBuffer shaped_buffer(static_cast(*this)); - buffers_ = ShapeTree(); - return shaped_buffer; -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index e10fca9e9466c018f6cb4da2f5618e4db4977307..905a7e82e621f2bf4588b71be5dbab20f892cafe 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -148,13 +148,29 @@ class ScopedShapedBuffer : public ShapedBuffer { // ScopedShapedBuffer. DeviceMemoryAllocator* memory_allocator() const { return allocator_; } - // Releases all device memory owned by this ScopedShapedBuffer and returns the - // device memory pointers in the form of a ShapedBuffer. The returned - // ShapedBuffer takes over the memory from the ScopedShapedBuffer. The - // resulting ScopedShapedBuffer can only be destroyed. - ShapedBuffer release(); + // Sets the device memory buffer at the given index. + // + // If the given buffer's device memory is non-null, its device_ordinal and + // allocator must match those in `this`. 
+ void set_buffer(OwningDeviceMemory buffer, const ShapeIndex& index) { + if (!buffer.is_null()) { + CHECK_EQ(buffer.device_ordinal(), device_ordinal()); + CHECK_EQ(buffer.allocator(), allocator_); + *buffers_.mutable_element(index) = buffer.Forget(); + } else { + *buffers_.mutable_element(index) = se::DeviceMemoryBase(); + } + } + + // Like unique_ptr::release(), creates and returns a regular ShapedBuffer from + // this ScopedShapedBuffer, without freeing any of the associated memory. + // + // It's the caller's job to ensure that the memory contained therein is freed. + TF_MUST_USE_RESULT ShapedBuffer release(); protected: + void Deallocate(); + DeviceMemoryAllocator* allocator_; }; diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0fc243667911651c788e3c1e5f1d39d86170f1ad --- /dev/null +++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc @@ -0,0 +1,110 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/shaped_buffer.h" + +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace xla { +namespace { + +TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) { + TF_ASSERT_OK_AND_ASSIGN(auto platforms, + xla::PlatformUtil::GetSupportedPlatforms()); + ASSERT_FALSE(platforms.empty()); + auto* platform = platforms[0]; + TF_ASSERT_OK_AND_ASSIGN(auto executors, + xla::PlatformUtil::GetStreamExecutors(platform)); + xla::StreamExecutorMemoryAllocator allocator(platform, executors); + const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {}); + const int kDeviceOrdinal = 0; + auto scoped_buffer = tensorflow::MakeUnique( + shape, shape, &allocator, kDeviceOrdinal); + std::unique_ptr buffer = std::move(scoped_buffer); + buffer = nullptr; +} + +class TestAllocator : public DeviceMemoryAllocator { + public: + TestAllocator() + : DeviceMemoryAllocator(PlatformUtil::GetDefaultPlatform().ValueOrDie()) { + } + + ~TestAllocator() override { + if (!allocations_.empty()) { + ADD_FAILURE() << "Some allocations not freed!"; + } + } + + // Pull in two-arg overload of Allocate. + using DeviceMemoryAllocator::Allocate; + + StatusOr Allocate(int device_ordinal, uint64 size, + bool /*retry_on_failure*/) override { + // By contract, we must return null if size == 0. 
+ if (size == 0) { + return OwningDeviceMemory(); + } + void* buf = malloc(size); + allocations_.insert({device_ordinal, buf}); + return OwningDeviceMemory(se::DeviceMemoryBase(buf, size), device_ordinal, + this); + } + + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override { + if (mem.is_null()) { + return Status::OK(); + } + + auto it = allocations_.find({device_ordinal, mem.opaque()}); + if (it == allocations_.end()) { + ADD_FAILURE() << "Allocation not found (double free?)"; + } else { + free(mem.opaque()); + allocations_.erase(it); + } + return Status::OK(); + } + + bool AllowsAsynchronousDeallocation() const override { return false; } + + private: + std::set> allocations_; +}; + +TEST(ScopedShapedBufferTest, TestMoveAssignmentOperator) { + Shape s = ShapeUtil::MakeShape(F32, {1}); + TestAllocator allocator; + ScopedShapedBuffer sb1(s, s, &allocator, /*device_ordinal=*/0); + sb1.set_buffer( + allocator.Allocate(/*device_ordinal=*/0, /*size=*/42).ValueOrDie(), + /*index=*/{}); + + ScopedShapedBuffer sb2(s, s, &allocator, /*device_ordinal=*/1); + sb2.set_buffer( + allocator.Allocate(/*device_ordinal=*/1, /*size=*/10).ValueOrDie(), + /*index=*/{}); + + sb1 = std::move(sb2); + + // TestAllocator's destructor checks that all memory was freed. 
+} + +} // anonymous namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 8b71a415091f028b3167cddb2583754e72ba17c8..c4d01562c4e32225ebb984d8fcd93ec3fa86e403 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -37,7 +37,7 @@ TransferManager::GetPlatformTransferManagers() { } Status TransferManager::TransferArrayToDevice( - se::StreamExecutor* executor, const Literal& literal, + se::StreamExecutor* executor, const LiteralSlice& literal, const se::DeviceMemoryBase& dest) { const Shape on_device_shape = HostShapeToDeviceShape(literal.shape()); TF_RET_CHECK(ShapeUtil::IsArray(on_device_shape)) @@ -196,9 +196,11 @@ StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer( const ShapeIndex& index = pair.first; se::DeviceMemoryBase& memory_base = pair.second; const Shape& subshape = ShapeUtil::GetSubshape(on_device_shape, index); - TF_ASSIGN_OR_RETURN(memory_base, + TF_ASSIGN_OR_RETURN(auto memory, allocator->Allocate(shaped_buffer.device_ordinal(), GetByteSizeRequirement(subshape))); + // Move the allocated buffer into the ScopedShapedBuffer, which owns it.
+ memory_base = memory.Forget(); } return std::move(shaped_buffer); diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index d82b4f0f81b5da38c1caf80bddefa0d3f7842463..43a8092b06fba0e2495bce0ee1a309c85a908273 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -65,14 +65,14 @@ class TransferManager { // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, // but need not have the same layout virtual Status TransferLiteralToDevice(se::StreamExecutor* executor, - const Literal& literal, + const LiteralSlice& literal, const ShapedBuffer& device_buffer) = 0; // Convenience methods for transferring an array to or from the device at a // known address. This avoids having to construct a ShapedBuffer just to // transfer an array at a known address. Status TransferArrayToDevice(se::StreamExecutor* executor, - const Literal& literal, + const LiteralSlice& literal, const se::DeviceMemoryBase& dest); StatusOr> TransferArrayFromDevice( se::StreamExecutor* executor, const Shape& shape, @@ -81,7 +81,7 @@ class TransferManager { // Transfers the given literal into the Infeed interface of the device, // using the given executor. virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) = 0; + const LiteralSlice& literal) = 0; // Transfers the given literal from the Outfeed interface of the device, // using the given executor. 
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index f7a5512fec47f75a72d31464ebac556ae41b36b9..ba16dc640e2d2974eab4fc8b134a6e33c03e3b85 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -215,7 +215,7 @@ StatusOr<bool> TransposeFolding::Run(HloModule* module) { std::make_pair(instruction, operand_indices)); } } - return tensorflow::Status::OK(); + return Status::OK(); }; for (auto* comp : module->MakeNonfusionComputations()) { diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index f73f1227aaf1630a9e7c43bb508732c5518ef929..cccb8f2fbb0266bbf1f40b09170938a1e5d3e78d 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -27,12 +27,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" @@ -69,7 +69,7 @@ ENTRY entry_computation { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); FoldTranspose(module.get()); @@ -91,7 +91,7 @@ ENTRY entry_computation { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TransposeFolding
transpose_folding( [](const HloInstruction& dot, @@ -119,7 +119,7 @@ ENTRY entry_computation { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TransposeFolding transpose_folding( [](const HloInstruction& dot, @@ -147,7 +147,7 @@ ENTRY entry_computation { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); FoldTranspose(module.get()); @@ -176,7 +176,7 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build(mul)); HloInstruction* call = module->OutlineExpressionFromComputation( - {add, sub, mul}, "", entry_computation); + {add, sub, mul}, "entry", entry_computation); EXPECT_EQ(call, entry_computation->root_instruction()); HloComputation* callee_computation = call->to_apply(); // The arguments to the call should be const1, const2, and const3. @@ -205,7 +205,7 @@ ENTRY entry_computation { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); FoldTranspose(module.get()); const HloComputation* callee = module->GetComputationWithName("callee"); diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 657a8fe09ae9df906d695f7f49df72500d611792..eb6d1ada6b553f998fe06917dfdf0b5092cd79cd 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -273,6 +273,14 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) { + // A kDomain instruction aliases its operand. That is, the buffer of its + // result *is* the buffer of its operand, so just copy the operands points-to + // set.
+ CreateCopiedPointsToSet(domain, domain->operand(0)); + return Status::OK(); +} + Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) { // A kSlice instruction aliases its operand if the backend lowers it to an // in-place implementation. @@ -588,4 +596,206 @@ void TuplePointsToAnalysis::InstructionToString( }); } +bool TuplePointsToAnalysis::DoesNotUseOperandBuffer( + const HloInstruction* operand, const ShapeIndex& index, + const HloInstruction* user) const { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) { + // GetTupleElement instructions only access the top-level buffer of their + // operand. + return true; + } else if (user->opcode() == HloOpcode::kFusion && + user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + // Find fusion parameter associated with 'operand'. + auto it = std::find_if( + user->fused_parameters().begin(), user->fused_parameters().end(), + [=](HloInstruction* fused_param) { + return user->operand(fused_param->parameter_number()) == operand; + }); + CHECK(it != user->fused_parameters().end()); + // Iterate through all users of all buffer aliases of the buffer in the + // points-to set of fusion parameter at 'index'. + // Return false if any uses are detected at 'index', returns true otherwise. + const LogicalBuffer* buffer = GetBufferDefinedAt(*it, index).ValueOrDie(); + for (const BufferAlias& alias : GetBufferAliases(*buffer)) { + for (HloInstruction* alias_user : alias.instruction()->users()) { + if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), + alias_user)) { + continue; + } + // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'. + return false; + } + } + // Return true: found no uses of 'operand' at 'index' in 'user'. + return true; + } + return false; +} + +// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'. 
+// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index) +// where 'user' is a user of an alias of 'instruction' at 'index', and +// 'operand_index' is the operand index at which the alias appears in the +// operand list of 'user'. +std::vector<std::pair<HloInstruction*, int64>> +TuplePointsToAnalysis::GetAllUsesOfInstructionAtIndex( + HloInstruction* instruction, const ShapeIndex& index) const { + std::vector<std::pair<HloInstruction*, int64>> uses; + const PointsToSet::BufferList& points_to = + GetPointsToSet(instruction).element(index); + for (const LogicalBuffer* buffer : points_to) { + for (const BufferAlias& alias : GetBufferAliases(*buffer)) { + for (HloInstruction* alias_user : alias.instruction()->users()) { + if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), + alias_user)) { + continue; + } + for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) { + uses.emplace_back(alias_user, op_idx); + } + } + } + } + return uses; +} + +// Returns true if there is exactly one use of 'operand' at 'operand_index' +// in 'fusion.fused_instructions', where the singleton use is the fused +// root at operand index 'use_operand_index'. Returns false otherwise. +// +// REQUIRES: 'fusion' opcode is a kFusion instruction. +bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt( + HloInstruction* operand, const ShapeIndex& operand_index, + HloInstruction* fusion, const int64 use_operand_index) const { + CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); + // Check that 'operand' is unique in the operand list of 'fusion'. + if (fusion->OperandIndices(operand).size() > 1) { + return false; + } + // Find fusion parameter associated with 'operand'.
+ const auto& fused_params = fusion->fused_parameters(); + auto fused_param_it = std::find_if( + fused_params.begin(), fused_params.end(), + [&](HloInstruction* fused_param) { + return fusion->operand(fused_param->parameter_number()) == operand; + }); + if (fused_param_it == fused_params.end()) { + return false; + } + auto* fused_param = *fused_param_it; + // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'. + auto fused_param_uses = + GetAllUsesOfInstructionAtIndex(fused_param, operand_index); + // Return true iff there is exactly one use of 'operand' at 'index', and + // this singleton use is the fused root (at index in 'use_operand_indices'). + return fused_param_uses.size() == 1 && + fused_param_uses[0].first == fusion->fused_expression_root() && + fused_param_uses[0].second == use_operand_index; +} + +// User and operand can share buffers iff both instructions emit the same shape +// and layout, and 'user' meets one of the following qualifications: +// +// (1) Is element-wise. Or... +// (2) Is a loop fusion instruction where the only use of 'operand' at 'index' +// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root +// at operand 0. Or... +// (3) Is a kDot -> kAdd output fusion instruction where the only use of +// 'operand' at 'index' in the set 'user.fused_instructions' is a kAdd fused +// root at operand 0 or 1. Or... +// (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index +// 0. +// +// (2) and (3) can only be determined if points-to analysis is available. 
+bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( + HloInstruction* operand, const ShapeIndex& operand_index, + HloInstruction* user, const ShapeIndex& user_index) const { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + const Shape& operand_subshape = + ShapeUtil::GetSubshape(operand->shape(), operand_index); + const Shape& user_subshape = + ShapeUtil::GetSubshape(user->shape(), user_index); + // Check that operand and user emit the same shape and layout. + if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { + return false; + } + if (user->opcode() == HloOpcode::kFusion) { + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + if (user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. + return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0); + } + } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && + user->fused_expression_root()->opcode() == HloOpcode::kAdd) { + // Output fusion with kAdd fused root. + + // Check if one operand of kAdd fused root is kDot or kConvolution. + auto* add = user->fused_expression_root(); + auto add_operand_it = + std::find_if(add->operands().begin(), add->operands().end(), + [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot; + }); + if (add_operand_it == add->operands().end()) { + return false; + } + auto* matched_add_operand = *add_operand_it; + // Calculate operand index of 'add' operand which was not matched above. + const int64 other_add_operand_index = + matched_add_operand == add->operand(0) ? 
1 : 0; + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root (at operand + // index 'other_add_operand_index'). + return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, + other_add_operand_index); + } + } + if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kWhile) { + // We eliminated other users in BufferLiveness::live_range_strictly_before, + // so here we just need to check that the use is at operand index 0. + std::vector<int64> operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && operand_indices[0] == 0; + } + if (user->opcode() == HloOpcode::kCall) { + // TODO(b/62548313): Remove when buffer assignment is module scoped and + // does not assign buffers to calls. + // Find called computation parameter associated with 'operand'. + const std::vector<int64> operand_indices = user->OperandIndices(operand); + if (operand_indices.size() > 1) { + return false; + } + CHECK_EQ(1, operand_indices.size()); + auto* param = user->to_apply()->parameter_instruction(operand_indices[0]); + // Get all uses of 'operand' at 'index' in called computation. + auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index); + + // Return true iff: + // *) There exists exactly one use of 'operand' in called computation. + // *) The unique use is by the root instruction of called computation. + // (Note: we check the root of the called computation, because the + // root result buffer is required to alias with the Call result buffer). + // *) The root instruction of the called computation is element-wise on + // 'operand'.
+ auto* callee_root = user->to_apply()->root_instruction(); + return param_uses.size() == 1 && param_uses[0].first == callee_root && + callee_root->IsElementwiseOnOperand(param_uses[0].second); + } + // Loop fusions that contain transposing copies won't reach here as they have + // different layouts, which fails the check in the beginning of this function. + // + // Multi-output fusion will fail the check here as tuples are not considered + // an elementwise operation. + return user->IsElementwiseOnOperand(user->operand_index(operand)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index c3743b150168ebcf1051050dc511e50c43108c4f..c0d82414806d9a6ff57aec59d077f444137fec9a 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -248,6 +248,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleTuple(HloInstruction* tuple) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleDomain(HloInstruction* domain) override; Status HandleSlice(HloInstruction* slice) override; Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; @@ -256,6 +257,23 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { string ToString() const; + // Returns true if 'user' cannot possibly use the buffer at 'index' in + // 'operand'. Returns false otherwise. + // + // REQUIRES: 'operand' is an operand of 'user'. + bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user) const; + + // Returns true if 'user' (at 'user_index') can share a buffer with its + // operand 'operand' (at 'operand_index'). Returns false otherwise. 
+ // + // REQUIRES: 'operand' is an operand of 'user'. + bool CanShareOperandBufferWithUser(HloInstruction* operand, + const ShapeIndex& operand_index, + HloInstruction* user, + const ShapeIndex& user_index) const; + private: explicit TuplePointsToAnalysis( const HloModule* module, @@ -310,6 +328,13 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { return &per_instruction_[id]; } + std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex( + HloInstruction* instruction, const ShapeIndex& index) const; + bool HasUniqueFusedUseOfOperandAt(HloInstruction* operand, + const ShapeIndex& operand_index, + HloInstruction* fusion, + const int64 use_operand_index) const; + // The module this analysis is performed on. const HloModule* module_; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index dec446d4dac650ba43992f7870764eedc80cb2cf..5734f284071944bc22011405898cf86f33dc48d7 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -805,5 +805,373 @@ TEST_F(FusionPointsToAnalysisTest, FusionParam0TwoUsers) { Run(/*add_additional_gte0_user=*/true); } +class PointsToAnalysisTestBase : public HloTestBase { + protected: + void BuildModule(std::unique_ptr<HloComputation> computation) { + module_ = CreateNewModule(); + computation_ = module_->AddEntryComputation(std::move(computation)); + } + + void RunAnalysis() { + CHECK_NOTNULL(module_.get()); + points_to_analysis_ = + TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); + } + + void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) { + BuildModule(std::move(computation)); + RunAnalysis(); + } + + std::unique_ptr<HloModule> module_; + HloComputation* computation_ = nullptr; + std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_; +}; + +class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {}; + +TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) { +
auto builder = HloComputation::Builder(TestName()); + + Shape elem_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1)); + builder.AddInstruction( + HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // GetTupleElement instructions only access the top-level buffer of their + // operand. + EXPECT_TRUE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {0}, gte0)); + EXPECT_TRUE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {1}, gte1)); + EXPECT_FALSE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte0)); + EXPECT_FALSE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte1)); +} + +TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. 
+ auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction never uses tuple element 0, but does use element 1. + EXPECT_TRUE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion)); + EXPECT_FALSE( + points_to_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion)); +} + +class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {}; + +TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto log = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {})); + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { + auto builder = HloComputation::Builder(TestName()); + + Shape in_shape = ShapeUtil::MakeShape(F32, {8}); + Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, in_shape, "param0")); + auto param1 
= builder.AddInstruction( + HloInstruction::CreateParameter(1, in_shape, "param1")); + auto result = builder.AddInstruction( + HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(param0, {}, + result, {})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(param1, {}, + result, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {})); + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(exp, {}, copy, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. 
+ auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction can share with tuple element 1. + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(tuple, {0}, + fusion, {})); + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(tuple, {1}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape update_shape = ShapeUtil::MakeShape(F32, {4}); + Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto update = builder.AddInstruction( + HloInstruction::CreateParameter(1, update_shape, "update")); + auto starts = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "starts")); + auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, data, update, starts)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The DynamicUpdateSlice instruction can share with the data operand, but not + // with update or starts. 
+ EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(data, {}, dus, {})); + EXPECT_FALSE( + points_to_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {})); + EXPECT_FALSE( + points_to_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto a = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + auto b = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction( + HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto add_operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kAdd, dot, add_operand)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, dot}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused dot add should be able to share buffer with 'add_operand'. 
+ EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser( + add_operand, {}, fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(data_shape, operand, {0, 1})); + + auto two = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, two, reverse}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused operand->reverse->add cannot alias operand buffer 'operand'. 
+ EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(operand, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + + auto make_cond = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Cond"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); + return builder.Build(); + }; + + auto make_body = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Body"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data)); + return builder.Build(); + }; + + module_ = CreateNewModule(); + HloComputation* cond_computation = + module_->AddEmbeddedComputation(make_cond()); + HloComputation* body_computation = + module_->AddEmbeddedComputation(make_body()); + + auto builder = HloComputation::Builder(TestName()); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto whil = builder.AddInstruction(HloInstruction::CreateWhile( + data_shape, cond_computation, body_computation, data)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + // The While instruction can share with the data operand. + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(data, {}, whil, {})); +} + +// Tests that Call can alias operand buffer if the only use of the operand +// in the called computation is an elementwise instruction. +TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { + Shape shape = ShapeUtil::MakeShape(F32, {8}); + // Build sub-computation with fusion root. 
+ auto sub_builder = HloComputation::Builder(TestName() + "_sub"); + auto sub_param = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "sub_param")); + auto one = sub_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto ones = sub_builder.AddInstruction( + HloInstruction::CreateBroadcast(shape, one, {1})); + auto add = sub_builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); + + module_ = CreateNewModule(); + auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); + sub_computation->CreateFusionInstruction({add, ones}, + HloInstruction::FusionKind::kLoop); + + // Build entry-computation with kCall which calls 'sub_computation'. + auto builder = HloComputation::Builder(TestName()); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto reverse = + builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0})); + auto call = builder.AddInstruction( + HloInstruction::CreateCall(shape, {reverse}, sub_computation)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(reverse, {}, + call, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, LoopFusionWithElementwiseOperand) { + Shape full_shape = ShapeUtil::MakeShape(F32, {16, 32}); + Shape broadcast_shape = ShapeUtil::MakeShape(F32, {16}); + + auto builder = HloComputation::Builder(TestName() + "_fusion"); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, full_shape, "full")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, broadcast_shape, "small")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(full_shape, param1, {0})); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + full_shape, HloOpcode::kAdd, param0, broadcast)); + + 
BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, broadcast}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.cc b/tensorflow/compiler/xla/service/tuple_simplifier.cc index 113c2e2bd9f73a2b0c783103d7f2da9534bc97c3..77bdcc9de0d830991208a1db271d009bccaf550e 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier.cc @@ -30,10 +30,17 @@ limitations under the License. namespace xla { +TupleSimplifier::TupleSimplifier(bool exclude_entry_computation) : + exclude_entry_computation_(exclude_entry_computation) {} + StatusOr<bool> TupleSimplifier::Run(HloModule* module) { // Initially add all GTE and Tuple instructions to the worklist.
std::queue<HloInstruction*> worklist; for (auto* computation : module->computations()) { + if (exclude_entry_computation_ && + computation == module->entry_computation()) { + continue; + } for (auto* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kTuple || instruction->opcode() == HloOpcode::kGetTupleElement) { @@ -78,7 +85,6 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) { can_simplify = false; break; } - if (top_tuple == nullptr) { top_tuple = operand->mutable_operand(0); if (!ShapeUtil::Compatible(top_tuple->shape(), @@ -108,10 +114,10 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) { // | // GTE if (instruction->operand(0)->opcode() == HloOpcode::kTuple) { - changed = true; HloInstruction* element_source = instruction->mutable_operand(0)->mutable_operand( instruction->tuple_index()); + changed = true; TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source)); for (HloInstruction* user : element_source->users()) { if (user->opcode() == HloOpcode::kTuple || diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h index e5e9b10b5bf3f452d1bfec476b8d5c7d74c4f4e8..750950188312c5077d487f2feef0606f07839432 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.h +++ b/tensorflow/compiler/xla/service/tuple_simplifier.h @@ -27,13 +27,20 @@ namespace xla { // the module. class TupleSimplifier : public HloPassInterface { public: - TupleSimplifier() {} + TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {} + explicit TupleSimplifier(bool exclude_entry_computation); ~TupleSimplifier() override {} tensorflow::StringPiece name() const override { return "tuple-simplifier"; } // Run tuple simplification on the given computation. Returns whether the // computation was changed.
StatusOr<bool> Run(HloModule* module) override; + + private: + // When set, this pipeline stage will perform optimization of all computations + // apart from the module's entry computation. This is used by Graphcore's + // backend. + bool exclude_entry_computation_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc index ca9ae91281fce5ee061d066fc3e538dbbc09f6b3..d3635eae81ec7017f9bf6a69250d10716309c9ec 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -42,6 +42,12 @@ class TupleSimplifierTest : public HloTestBase { TF_ASSERT_OK(changed_status.status()); EXPECT_EQ(change_expected, changed_status.ValueOrDie()); } + void Run(HloModule* module, bool change_expected, bool exclude_entry) { + TupleSimplifier simplifier(exclude_entry); + auto changed_status = simplifier.Run(module); + TF_ASSERT_OK(changed_status.status()); + EXPECT_EQ(change_expected, changed_status.ValueOrDie()); + } const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( @@ -211,5 +217,76 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) { EXPECT_THAT(computation->root_instruction(), tuple); } +TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { + // Verify that the root computation can be excluded + auto module = CreateNewModule(); + + HloInstruction* p0; + HloInstruction* p1; + HloComputation* c0; + HloComputation* c1; + HloComputation* entry; + + { + HloComputation::Builder builder(TestName() + "_1"); + p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 1)); + HloInstruction* gte2 =
builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 2)); + + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2})); + + c0 = module->AddEmbeddedComputation(builder.Build()); + } + { + HloComputation::Builder builder(TestName() + "_2"); + p1 = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 1)); + HloInstruction* gte2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 2)); + + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2})); + + c1 = module->AddEmbeddedComputation(builder.Build()); + } + { + HloComputation::Builder builder(TestName() + "_Entry"); + HloInstruction* tuple_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + HloInstruction* call0 = builder.AddInstruction( + HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c0)); + HloInstruction* call1 = builder.AddInstruction( + HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c1)); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, call0, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, call1, 1)); + HloInstruction* tuple0 = + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + HloInstruction* gte2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 0)); + HloInstruction* gte3 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 1)); + + builder.AddInstruction(HloInstruction::CreateTuple({gte2, gte3})); + + entry = module->AddEntryComputation(builder.Build()); + } + + Run(module.get(), 
/*change_expected=*/true, /*exclude_entry=*/ true); + + EXPECT_THAT(c0->root_instruction(), p0); + EXPECT_THAT(c1->root_instruction(), p1); + EXPECT_THAT(entry->instruction_count(), 9); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_util_test.cc b/tensorflow/compiler/xla/service/tuple_util_test.cc index 754fd8ef169231827eeb5bfd72aeb596644ca767..d33d5bb8f30c8504aa323d461e5f59709b48e1fc 100644 --- a/tensorflow/compiler/xla/service/tuple_util_test.cc +++ b/tensorflow/compiler/xla/service/tuple_util_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_util.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -37,7 +37,7 @@ ENTRY entry { )"; TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); *entry_computation = module->entry_computation(); *param0 = (*entry_computation)->parameter_instruction(0); diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc deleted file mode 100644 index 0f16a592b68e20f5dbd1e4655ad5720ecce5a7bd..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ /dev/null @@ -1,3553 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/user_computation.h" - -#include -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/shape_inference.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/protobuf.h" - -namespace xla { -namespace { - -HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { - switch (unop) { - case UNOP_ABS: - return HloOpcode::kAbs; - case UNOP_CEIL: - return HloOpcode::kCeil; - case UNOP_CLZ: - return HloOpcode::kClz; - case UNOP_COS: - return HloOpcode::kCos; - case UNOP_EXP: - return HloOpcode::kExp; - case UNOP_FLOOR: - return HloOpcode::kFloor; - case UNOP_IMAG: - return HloOpcode::kImag; - case UNOP_IS_FINITE: - return HloOpcode::kIsFinite; - case UNOP_LOG: - return HloOpcode::kLog; - case UNOP_NOT: - return HloOpcode::kNot; - case UNOP_NEGATE: - return HloOpcode::kNegate; - case UNOP_REAL: - return HloOpcode::kReal; - case UNOP_ROUND_NEAREST_AFZ: - return HloOpcode::kRoundNearestAfz; - case UNOP_SIGN: - return HloOpcode::kSign; - case UNOP_SIN: - return HloOpcode::kSin; 
- case UNOP_SORT: - return HloOpcode::kSort; - case UNOP_TANH: - return HloOpcode::kTanh; - default: - LOG(FATAL) << "unhandled operation " << unop; - } -} - -HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) { - switch (binop) { - case BINOP_ATAN2: - return HloOpcode::kAtan2; - case BINOP_COMPLEX: - return HloOpcode::kComplex; - case BINOP_MUL: - return HloOpcode::kMultiply; - case BINOP_ADD: - return HloOpcode::kAdd; - case BINOP_SUB: - return HloOpcode::kSubtract; - case BINOP_DIV: - return HloOpcode::kDivide; - case BINOP_EQ: - return HloOpcode::kEq; - case BINOP_GE: - return HloOpcode::kGe; - case BINOP_GT: - return HloOpcode::kGt; - case BINOP_LE: - return HloOpcode::kLe; - case BINOP_LT: - return HloOpcode::kLt; - case BINOP_NE: - return HloOpcode::kNe; - case BINOP_MAX: - return HloOpcode::kMaximum; - case BINOP_MIN: - return HloOpcode::kMinimum; - case BINOP_POW: - return HloOpcode::kPower; - case BINOP_REM: - return HloOpcode::kRemainder; - case BINOP_OR: - return HloOpcode::kOr; - case BINOP_AND: - return HloOpcode::kAnd; - case BINOP_SHIFT_LEFT: - return HloOpcode::kShiftLeft; - case BINOP_SHIFT_RIGHT_ARITHMETIC: - return HloOpcode::kShiftRightArithmetic; - case BINOP_SHIFT_RIGHT_LOGICAL: - return HloOpcode::kShiftRightLogical; - default: - LOG(FATAL) << "unhandled operation " << binop; - } -} - -HloOpcode TernaryOperationToHloOpcode(TernaryOperation triop) { - switch (triop) { - case TRIOP_CLAMP: - return HloOpcode::kClamp; - case TRIOP_SELECT: - return HloOpcode::kSelect; - default: - LOG(FATAL) << "unhandled operation " << triop; - } -} - -HloOpcode VariadicOperationToHloOpcode(VariadicOperation varop) { - switch (varop) { - case VAROP_TUPLE: - return HloOpcode::kTuple; - default: - LOG(FATAL) << "unhandled operation " << varop; - } -} - -} // namespace - -/* static */ StatusOr> -UserComputation::MakeWithRemapping( - const SessionComputation& session_computation, - const ComputationHandle& handle, - const std::map& old_to_new) { - auto 
user_computation = - MakeUnique(session_computation.name(), handle); - { - tensorflow::mutex_lock lock(user_computation->mutex_); - user_computation->session_computation_ = session_computation; - user_computation->next_handle_value_ = - std::max_element(session_computation.requests().begin(), - session_computation.requests().end(), - [](const std::pair& lhs, - const std::pair& rhs) { - return lhs.first < rhs.first; - }) - ->first + - 1; - TF_RETURN_IF_ERROR(user_computation->RemapEmbeddedComputations(old_to_new)); - } - - return std::move(user_computation); -} - -UserComputation::UserComputation(const string& name, - const ComputationHandle& handle) - : name_(name), next_handle_value_(1) { - *session_computation_.mutable_computation_handle() = handle; - session_computation_.set_name(name); - - VLOG(1) << "New UserComputation \"" << name - << "\", handle: " << handle.handle(); -} - -ComputationDataHandle UserComputation::CreateComputationDataHandle() { - ComputationDataHandle handle; - handle.set_handle(next_handle_value_); - // Handles are used as Version values and *must* be assigned consecutively for - // computation versioning to work. 
- next_handle_value_++; - return handle; -} - -StatusOr UserComputation::AddParameterInstruction( - const ParameterRequest& parameter_request) { - tensorflow::mutex_lock lock(mutex_); - - int64 parameter_number = parameter_request.parameter(); - if (parameters_.count(parameter_number) != 0) { - return InvalidArgument("parameter %lld already registered", - parameter_number); - } - ComputationDataHandle handle = CreateComputationDataHandle(); - - const Shape& validated_shape = parameter_request.shape(); - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = validated_shape; - *request.mutable_request()->mutable_parameter_request() = parameter_request; - - parameters_[parameter_number] = &request; - - VLOG(1) << "AddParameterInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << parameter_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddSendInstruction( - const SendRequest& send_request) { - tensorflow::mutex_lock lock(mutex_); - - // Check if the operand of the instruction is valid. - TF_RETURN_IF_ERROR(LookUpRequest(send_request.operand()).status()); - - // No handle is returned, but a handle must be assigned to this instruction - // for computation versioning. 
- ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = ShapeUtil::MakeNil(); - *request.mutable_request()->mutable_send_request() = send_request; - - VLOG(1) << "AddSendInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << send_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddRecvInstruction( - const RecvRequest& recv_request) { - tensorflow::mutex_lock lock(mutex_); - - const Shape& shape = recv_request.shape(); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_recv_request() = recv_request; - - VLOG(1) << "AddRecvInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << recv_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddPadInstruction( - const PadRequest& pad_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(pad_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* padding_value, - LookUpRequest(pad_request.padding_value())); - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, ShapeInference::InferPadShape( - operand->output_shape(), - padding_value->output_shape(), - pad_request.padding_config())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = 
inferred_shape; - *request.mutable_request()->mutable_pad_request() = pad_request; - - VLOG(1) << "AddPadInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << pad_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConstantInstruction( - const ConstantRequest& constant_request) { - const Shape& validated_shape = constant_request.literal().shape(); - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); - - tensorflow::mutex_lock lock(mutex_); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = validated_shape; - *request.mutable_request()->mutable_constant_request() = constant_request; - - VLOG(1) << "AddConstantInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle(); - return handle; -} - -StatusOr UserComputation::AddGatherInstruction( - const GatherRequest& gather_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* input_request, - LookUpRequest(gather_request.input())); - TF_ASSIGN_OR_RETURN(const OperationRequest* gather_indices_request, - LookUpRequest(gather_request.gather_indices())); - - TF_ASSIGN_OR_RETURN( - Shape shape, - ShapeInference::InferGatherShape( - input_request->output_shape(), gather_indices_request->output_shape(), - gather_request.dimension_numbers(), - AsInt64Slice(gather_request.window_bounds()))); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_gather_request() = gather_request; - - VLOG(1) << "AddGatherInstruction (" << 
GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << gather_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddGetTupleElementInstruction( - const GetTupleElementRequest& get_tuple_element_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(get_tuple_element_request.operand())); - if (!ShapeUtil::IsTuple(operand->output_shape())) { - return InvalidArgument( - "Operand to GetTupleElement() is not a tuple; got %s", - ShapeUtil::HumanString(operand->output_shape()).c_str()); - } - Shape element_shape = ShapeUtil::GetTupleElementShape( - operand->output_shape(), get_tuple_element_request.index()); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = element_shape; - *request.mutable_request()->mutable_get_tuple_element_request() = - get_tuple_element_request; - - VLOG(1) << "AddGetTupleElementInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << get_tuple_element_request.ShortDebugString(); - return handle; -} - -Status UserComputation::AddTraceInstruction(const TraceRequest& trace_request) { - tensorflow::mutex_lock lock(mutex_); - - // Verify that the operand index is valid. 
- TF_RETURN_IF_ERROR(LookUpRequest(trace_request.operand()).status()); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = ShapeUtil::MakeNil(); - *request.mutable_request()->mutable_trace_request() = trace_request; - - VLOG(1) << "AddTraceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << trace_request.ShortDebugString(); - return Status::OK(); -} - -StatusOr UserComputation::AddRngInstruction( - const RngRequest& rng_request) { - tensorflow::mutex_lock lock(mutex_); - - // Check the number of parameters per RNG distribution. - switch (rng_request.distribution()) { - case RandomDistribution::RNG_NORMAL: - case RandomDistribution::RNG_UNIFORM: - if (rng_request.parameter_size() != 2) { - return InvalidArgument( - "RNG distribution (%s) expects 2 parameters, but got %d", - RandomDistribution_Name(rng_request.distribution()).c_str(), - rng_request.parameter_size()); - } - break; - default: - LOG(FATAL) << "unhandled distribution " << rng_request.distribution(); - } - - // Verify that the parameter indices are valid; - for (const ComputationDataHandle& param : rng_request.parameter()) { - TF_RETURN_IF_ERROR(LookUpRequest(param).status()); - } - const Shape& validated_shape = rng_request.shape(); - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = validated_shape; - *request.mutable_request()->mutable_rng_request() = rng_request; - - VLOG(1) << "AddRngInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << 
rng_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddMapInstruction( - const MapRequest& map_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : map_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferMapShape(operand_shapes, *to_apply_program_shape, - AsInt64Slice(map_request.dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_map_request() = map_request; - - VLOG(1) << "AddMapInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << map_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReduceInstruction( - const ReduceRequest& reduce_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reduce_request.operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookUpRequest(reduce_request.init_value())); - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - 
to_apply_computation.ComputeProgramShape(to_apply_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReduceShape( - operand->output_shape(), init_value->output_shape(), - AsInt64Slice(reduce_request.dimensions()), *to_apply_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_reduce_request() = reduce_request; - - VLOG(1) << "AddReduceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reduce_request.ShortDebugString(); - return handle; -} - -StatusOr -UserComputation::AddBatchNormTrainingInstruction( - const BatchNormTrainingRequest& batch_norm_training_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(batch_norm_training_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* scale, - LookUpRequest(batch_norm_training_request.scale())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* offset, - LookUpRequest(batch_norm_training_request.offset())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferBatchNormTrainingShape( - operand->output_shape(), scale->output_shape(), - offset->output_shape(), batch_norm_training_request.feature_index())); - - *request.mutable_output_shape() = inferred_shape; - - *request.mutable_output_handle() = handle; - - *request.mutable_request()->mutable_batch_norm_training_request() = - batch_norm_training_request; - - VLOG(1) << "AddBatchNormTrainingInstruction (" << 
GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << batch_norm_training_request.ShortDebugString(); - - return handle; -} - -StatusOr -UserComputation::AddBatchNormInferenceInstruction( - const BatchNormInferenceRequest& batch_norm_inference_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(batch_norm_inference_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* scale, - LookUpRequest(batch_norm_inference_request.scale())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* offset, - LookUpRequest(batch_norm_inference_request.offset())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* mean, - LookUpRequest(batch_norm_inference_request.mean())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* variance, - LookUpRequest(batch_norm_inference_request.variance())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferBatchNormInferenceShape( - operand->output_shape(), scale->output_shape(), - offset->output_shape(), mean->output_shape(), - variance->output_shape(), - batch_norm_inference_request.feature_index())); - - *request.mutable_output_shape() = inferred_shape; - - *request.mutable_output_handle() = handle; - - *request.mutable_request()->mutable_batch_norm_inference_request() = - batch_norm_inference_request; - - VLOG(1) << "AddBatchNormInferenceInstruction (" - << GetVersionedHandleInternal() << "), data handle " - << handle.handle() << ": " - << batch_norm_inference_request.ShortDebugString(); - - return handle; -} - -StatusOr UserComputation::AddBatchNormGradInstruction( - const BatchNormGradRequest& batch_norm_grad_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - 
LookUpRequest(batch_norm_grad_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* scale, - LookUpRequest(batch_norm_grad_request.scale())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* mean, - LookUpRequest(batch_norm_grad_request.mean())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* variance, - LookUpRequest(batch_norm_grad_request.variance())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* grad_output, - LookUpRequest(batch_norm_grad_request.grad_output())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferBatchNormGradShape( - operand->output_shape(), scale->output_shape(), mean->output_shape(), - variance->output_shape(), grad_output->output_shape(), - batch_norm_grad_request.feature_index())); - - *request.mutable_output_shape() = inferred_shape; - - *request.mutable_output_handle() = handle; - - *request.mutable_request()->mutable_batch_norm_grad_request() = - batch_norm_grad_request; - - VLOG(1) << "AddBatchNormGradInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << batch_norm_grad_request.ShortDebugString(); - - return handle; -} - -StatusOr UserComputation::AddReduceWindowInstruction( - const ReduceWindowRequest& reduce_window_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reduce_window_request.operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookUpRequest(reduce_window_request.init_value())); - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - - TF_ASSIGN_OR_RETURN( - Shape 
inferred_shape, - ShapeInference::InferReduceWindowShape( - operand->output_shape(), init_value->output_shape(), - reduce_window_request.window(), *to_apply_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_reduce_window_request() = - reduce_window_request; - - VLOG(1) << "AddReduceWindowInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reduce_window_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddSelectAndScatterInstruction( - const SelectAndScatterRequest& select_and_scatter_request, - const UserComputation& select_computation, - const UserComputation& scatter_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(select_and_scatter_request.operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* source, - LookUpRequest(select_and_scatter_request.source())); - TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookUpRequest(select_and_scatter_request.init_value())); - - VersionedComputationHandle::Version select_version = - select_computation.version(); - TF_ASSIGN_OR_RETURN(std::shared_ptr select_program_shape, - select_computation.ComputeProgramShape(select_version)); - VersionedComputationHandle::Version scatter_version = - scatter_computation.version(); - TF_ASSIGN_OR_RETURN(std::shared_ptr scatter_program_shape, - scatter_computation.ComputeProgramShape(scatter_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferSelectAndScatterShape( - operand->output_shape(), *select_program_shape, - select_and_scatter_request.window(), source->output_shape(), 
- init_value->output_shape(), *scatter_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(select_version); - request.add_embedded_computation_versions(scatter_version); - *request.mutable_request()->mutable_select_and_scatter_request() = - select_and_scatter_request; - - VLOG(1) << "AddSelectAndScatterInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << select_and_scatter_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReverseInstruction( - const ReverseRequest& reverse_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reverse_request.operand())); - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReverseShape( - operand->output_shape(), AsInt64Slice(reverse_request.dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_reverse_request() = reverse_request; - VLOG(1) << "AddReverseInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reverse_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddWhileInstruction( - const WhileRequest& while_request, - const UserComputation& condition_computation, - const UserComputation& body_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* init, - LookUpRequest(while_request.init())); - - VersionedComputationHandle::Version 
condition_version = - condition_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr condition_program_shape, - condition_computation.ComputeProgramShape(condition_version)); - - VersionedComputationHandle::Version body_version = body_computation.version(); - TF_ASSIGN_OR_RETURN(std::shared_ptr body_program_shape, - body_computation.ComputeProgramShape(body_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferWhileShape( - *condition_program_shape, *body_program_shape, init->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(condition_version); - request.add_embedded_computation_versions(body_version); - *request.mutable_request()->mutable_while_request() = while_request; - - VLOG(1) << "AddWhileInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << while_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConditionalInstruction( - const ConditionalRequest& conditional_request, - const UserComputation& true_computation, - const UserComputation& false_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* pred, - LookUpRequest(conditional_request.predicate())); - TF_ASSIGN_OR_RETURN(const OperationRequest* true_operand, - LookUpRequest(conditional_request.true_operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* false_operand, - LookUpRequest(conditional_request.false_operand())); - - VersionedComputationHandle::Version true_computation_version = - true_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr true_computation_shape, - true_computation.ComputeProgramShape(true_computation_version)); - - 
VersionedComputationHandle::Version false_computation_version = - false_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr false_computation_shape, - false_computation.ComputeProgramShape(false_computation_version)); - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferConditionalShape( - pred->output_shape(), true_operand->output_shape(), - false_operand->output_shape(), - *true_computation_shape, *false_computation_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(true_computation_version); - request.add_embedded_computation_versions(false_computation_version); - *request.mutable_request()->mutable_conditional_request() = - conditional_request; - - VLOG(1) << "AddConditionalInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << conditional_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddBroadcastInstruction( - const BroadcastRequest& broadcast_request) { - tensorflow::mutex_lock lock(mutex_); - - // Fetches and validates the operand. 
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(broadcast_request.operand())); - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferBroadcastShape( - operand->output_shape(), - AsInt64Slice(broadcast_request.broadcast_sizes()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_broadcast_request() = broadcast_request; - - VLOG(1) << "AddBroadcastInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << broadcast_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReshapeInstruction( - const ReshapeRequest& reshape_request) { - tensorflow::mutex_lock lock(mutex_); - - // Fetches and validates the operand. - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reshape_request.operand())); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReshapeShape( - operand->output_shape(), AsInt64Slice(reshape_request.dimensions()), - AsInt64Slice(reshape_request.new_sizes()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_reshape_request() = reshape_request; - - VLOG(1) << "AddReshapeInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reshape_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddTransposeInstruction( - const TransposeRequest& transpose_request) { - tensorflow::mutex_lock lock(mutex_); - - // Fetches and validates the operand. 
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(transpose_request.operand())); - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferTransposeShape( - operand->output_shape(), - AsInt64Slice(transpose_request.dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_transpose_request() = transpose_request; - - VLOG(1) << "AddTransposeInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << transpose_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddSliceInstruction( - const SliceRequest& slice_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(slice_request.operand())); - - TF_ASSIGN_OR_RETURN( - Shape new_shape, - ShapeInference::InferSliceShape( - operand->output_shape(), AsInt64Slice(slice_request.start_indices()), - AsInt64Slice(slice_request.limit_indices()), - AsInt64Slice(slice_request.strides()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_slice_request() = slice_request; - - VLOG(1) << "AddSliceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << slice_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddDynamicSliceInstruction( - const DynamicSliceRequest& dynamic_slice_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - 
LookUpRequest(dynamic_slice_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* start_indices, - LookUpRequest(dynamic_slice_request.start_indices())); - - TF_ASSIGN_OR_RETURN( - Shape new_shape, - ShapeInference::InferDynamicSliceShape( - operand->output_shape(), start_indices->output_shape(), - AsInt64Slice(dynamic_slice_request.slice_sizes()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_dynamic_slice_request() = - dynamic_slice_request; - - VLOG(1) << "AddDynamicSliceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << dynamic_slice_request.ShortDebugString(); - return handle; -} - -StatusOr -UserComputation::AddDynamicUpdateSliceInstruction( - const DynamicUpdateSliceRequest& dynamic_update_slice_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(dynamic_update_slice_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* update, - LookUpRequest(dynamic_update_slice_request.update())); - - TF_ASSIGN_OR_RETURN( - const OperationRequest* start_indices, - LookUpRequest(dynamic_update_slice_request.start_indices())); - - TF_ASSIGN_OR_RETURN(Shape new_shape, - ShapeInference::InferDynamicUpdateSliceShape( - operand->output_shape(), update->output_shape(), - start_indices->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_dynamic_update_slice_request() = - dynamic_update_slice_request; - - VLOG(1) << 
"AddDynamicUpdateSliceInstruction (" - << GetVersionedHandleInternal() << "), data handle " - << handle.handle() << ": " - << dynamic_update_slice_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConcatenateInstruction( - const ConcatenateRequest& concatenate_request) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : concatenate_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - TF_ASSIGN_OR_RETURN(Shape new_shape, - ShapeInference::InferConcatOpShape( - operand_shapes, concatenate_request.dimension())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_concatenate_request() = - concatenate_request; - - VLOG(1) << "AddConcatenateInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << concatenate_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConvertInstruction( - const ConvertRequest& convert_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(convert_request.operand())); - - TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape( - operand->output_shape(), - convert_request.new_element_type())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_convert_request() = convert_request; - - VLOG(1) << "AddConvertInstruction (" 
<< GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << convert_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddBitcastConvertInstruction( - const ConvertRequest& convert_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(convert_request.operand())); - - TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape( - operand->output_shape(), - convert_request.new_element_type())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_bitcast_convert_request() = - convert_request; - - VLOG(1) << "AddBitcastConvertInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << convert_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReducePrecisionInstruction( - const ReducePrecisionRequest& reduce_precision_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reduce_precision_request.operand())); - - TF_ASSIGN_OR_RETURN( - Shape new_shape, - ShapeInference::InferReducePrecisionShape( - operand->output_shape(), reduce_precision_request.exponent_bits(), - reduce_precision_request.mantissa_bits())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_reduce_precision_request() = - reduce_precision_request; - - VLOG(1) << "AddReducePrecisionInstruction (" << GetVersionedHandleInternal() - << "), data handle " << 
handle.handle() << ": " - << reduce_precision_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConvolveInstruction( - const ConvolveRequest& convolve_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(convolve_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(convolve_request.rhs())); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvolveShape( - lhs->output_shape(), rhs->output_shape(), - convolve_request.window(), - convolve_request.dimension_numbers())); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_convolve_request() = convolve_request; - - VLOG(1) << "AddConvolveInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << convolve_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddFftInstruction( - const FftRequest& fft_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(fft_request.operand())); - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferFftShape( - operand->output_shape(), fft_request.fft_type(), - AsInt64Slice(fft_request.fft_length()))); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_fft_request() = fft_request; - - VLOG(1) << "AddFftInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << fft_request.ShortDebugString(); - return 
handle; -} - -StatusOr UserComputation::AddCrossReplicaSumInstruction( - const CrossReplicaSumRequest& cross_replica_sum_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(cross_replica_sum_request.operand())); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape( - {&operand->output_shape()})); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_cross_replica_sum_request() = - cross_replica_sum_request; - - VLOG(1) << "AddCrossreplicaSumInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << cross_replica_sum_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddInfeedInstruction( - const InfeedRequest& infeed_request) { - tensorflow::mutex_lock lock(mutex_); - - const Shape& shape = infeed_request.shape(); - if (!LayoutUtil::HasLayout(shape)) { - return InvalidArgument("Given shape to Infeed must have a layout"); - } - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_infeed_request() = infeed_request; - - VLOG(1) << "AddInfeedInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << infeed_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddOutfeedInstruction( - const OutfeedRequest& outfeed_request) { - tensorflow::mutex_lock lock(mutex_); - - const Shape& shape = outfeed_request.shape(); - if (!LayoutUtil::HasLayout(shape)) { - return InvalidArgument("Given 
shape to Outfeed must have a layout"); - } - - // Verify that operand is valid. - TF_RETURN_IF_ERROR(LookUpRequest(outfeed_request.operand()).status()); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_outfeed_request() = outfeed_request; - - VLOG(1) << "AddOutfeedInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << outfeed_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddCallInstruction( - const CallRequest& call_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : call_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferCallShape(operand_shapes, *to_apply_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_call_request() = call_request; - - VLOG(1) << "AddCallInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << call_request.ShortDebugString(); - return handle; -} - -StatusOr 
UserComputation::AddCustomCallInstruction( - const CustomCallRequest& custom_call_request) { - tensorflow::mutex_lock lock(mutex_); - - for (const ComputationDataHandle& handle : custom_call_request.operands()) { - TF_RETURN_IF_ERROR(LookUpRequest(handle).status()); - } - - if (tensorflow::str_util::StartsWith(custom_call_request.call_target_name(), - "$")) { - return InvalidArgument( - "Invalid custom_call_target \"%s\": Call targets that start with '$' " - "are reserved for internal use.", - custom_call_request.call_target_name().c_str()); - } - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = custom_call_request.shape(); - *request.mutable_request()->mutable_custom_call_request() = - custom_call_request; - - VLOG(1) << "AddCustomCallInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << custom_call_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddHostComputeInstruction( - const HostComputeRequest& host_compute_request) { - tensorflow::mutex_lock lock(mutex_); - - for (const ComputationDataHandle& handle : host_compute_request.operands()) { - TF_RETURN_IF_ERROR(LookUpRequest(handle).status()); - } - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = host_compute_request.shape(); - *request.mutable_request()->mutable_host_compute_request() = - host_compute_request; - - VLOG(1) << "AddHostComputeInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << host_compute_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddDotInstruction( - const 
DotRequest& dot_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(dot_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(dot_request.rhs())); - - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape( - lhs->output_shape(), rhs->output_shape(), - dot_request.dimension_numbers())); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_dot_request() = dot_request; - - VLOG(1) << "AddDotInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << dot_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddUnaryInstruction( - const UnaryOpRequest& unary_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(unary_request.operand())); - TF_ASSIGN_OR_RETURN( - Shape shape, ShapeInference::InferUnaryOpShape(unary_request.unop(), - operand->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_unary_op_request() = unary_request; - - VLOG(1) << "AddUnaryInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << unary_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddBinaryInstruction( - const BinaryOpRequest& binary_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(binary_request.lhs())); - 
TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(binary_request.rhs())); - TF_ASSIGN_OR_RETURN( - Shape shape, - ShapeInference::InferBinaryOpShape( - binary_request.binop(), lhs->output_shape(), rhs->output_shape(), - AsInt64Slice(binary_request.broadcast_dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_binary_op_request() = binary_request; - - VLOG(1) << "AddBinaryInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << binary_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddTernaryInstruction( - const TernaryOpRequest& ternary_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(ternary_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(ternary_request.rhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* ehs, - LookUpRequest(ternary_request.ehs())); - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferTernaryOpShape( - ternary_request.triop(), lhs->output_shape(), - rhs->output_shape(), ehs->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_ternary_op_request() = ternary_request; - - VLOG(1) << "AddTernaryInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << ternary_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddVariadicInstruction( - const VariadicOpRequest& variadic_request) { - 
tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : variadic_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferVariadicOpShape( - variadic_request.varop(), operand_shapes)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_variadic_op_request() = variadic_request; - - VLOG(1) << "AddVariadicInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << variadic_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::GetShape(const ComputationDataHandle& handle) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - return operand->output_shape(); -} - -Status UserComputation::SetOpMetadata(const ComputationDataHandle& handle, - const OpMetadata& metadata) { - tensorflow::mutex_lock lock(mutex_); - - int64 handle_value = handle.handle(); - if (session_computation_.requests().count(handle_value) == 0) { - return InvalidArgument("Invalid handle in SetOpMetadata (%lld)", - handle_value); - } - *session_computation_.mutable_requests() - ->at(handle_value) - .mutable_request() - ->mutable_metadata() = metadata; - return Status::OK(); -} - -Status UserComputation::SetOpSharding(const ComputationDataHandle& handle, - const OpSharding& sharding) { - tensorflow::mutex_lock lock(mutex_); - - int64 handle_value = handle.handle(); - if (session_computation_.requests().count(handle_value) == 0) { - return InvalidArgument("Invalid handle in SetOpSharding (%lld)", - handle_value); - } - 
*session_computation_.mutable_requests() - ->at(handle_value) - .mutable_request() - ->mutable_sharding() = sharding; - return Status::OK(); -} - -Status UserComputation::SetReturnValue(const ComputationDataHandle& handle) { - tensorflow::mutex_lock lock(mutex_); - - if (!(handle.handle() > 0 && handle.handle() < next_handle_value_)) { - return InvalidArgument("Invalid handle in SetReturnValue"); - } - - handle_to_return_ = handle; - - VLOG(1) << "SetReturnValue of computation \"" << name() << "\" fixed to " - << GetVersionedHandleInternal(); - - return Status::OK(); -} - -VersionedComputationHandle UserComputation::GetVersionedHandle() const { - tensorflow::mutex_lock lock(mutex_); - return GetVersionedHandleInternal(); -} - -VersionedComputationHandle UserComputation::GetVersionedHandleInternal() const { - VersionedComputationHandle versioned_handle; - versioned_handle.handle = session_computation_.computation_handle(); - - if (handle_to_return_.handle() > 0) { - // A specific handle has been requested for the result of the computation. - versioned_handle.version = handle_to_return_.handle(); - } else { - // A version value is simply the most recently assigned - // ComputationDataHandle value, ie the handle value of the root of the - // computation. - versioned_handle.version = next_handle_value_ - 1; - } - - return versioned_handle; -} - -VersionedComputationHandle UserComputation::GetVersionedHandleAtOperation( - const ComputationDataHandle& operation) const { - tensorflow::mutex_lock lock(mutex_); - - // The version at which an operation was added is simply the handle value of - // the ComputationDataHandle. 
- VersionedComputationHandle versioned_handle; - versioned_handle.handle = session_computation_.computation_handle(); - versioned_handle.version = operation.handle(); - return versioned_handle; -} - -VersionedComputationHandle::Version UserComputation::version() const { - return GetVersionedHandle().version; -} - -namespace { - -// Returns true if the operation type corresponding to the given opcase can be -// the root of the computation. -bool CanBeRoot(const OpRequest::OpCase& op_case) { - switch (op_case) { - case OpRequest::kTraceRequest: - case OpRequest::kSendRequest: - case OpRequest::kOutfeedRequest: - return false; - default: - return true; - } -} - -// Returns a pointer to the operation with the given data handle value in the -// given SessionComputation. -StatusOr<const OperationRequest*> LookUpRequest( - int64 handle_value, const SessionComputation& session_computation) { - if (session_computation.requests().count(handle_value) == 0) { - return InvalidArgument("no ComputationDataHandle value %lld", handle_value); - } - return &session_computation.requests().at(handle_value); -} - -// Returns the OperationRequest corresponding to the root (result) of the -// session computation. -StatusOr<const OperationRequest*> GetRoot( - VersionedComputationHandle::Version version, - const SessionComputation& session_computation) { - TF_RET_CHECK(version > 0); - // Not all instructions can be roots. Walk backwards from the operation - // indicated by this version until a valid root is found.
- const OperationRequest* root_request = nullptr; - while (version > 0) { - TF_ASSIGN_OR_RETURN(root_request, - LookUpRequest(version, session_computation)); - if (CanBeRoot(root_request->request().op_case())) { - break; - } - version--; - } - if (version == 0) { - return InternalError("Computation contains no root operation"); - } - return root_request; -} - -} // namespace - -StatusOr<std::shared_ptr<const ProgramShape>> -UserComputation::ComputeProgramShape( - VersionedComputationHandle::Version version) const { - tensorflow::mutex_lock lock(mutex_); - - TF_RET_CHECK(version > 0 && version < next_handle_value_); - - if (program_shape_ == nullptr || program_shape_version_ != version) { - // ProgramShape has not been computed yet, or is for different - // version. Compute it now. - TF_RETURN_IF_ERROR(CheckParametersAreContiguous(version)); - - auto program_shape = MakeUnique<ProgramShape>(); - for (int64 request_num = 1; request_num <= version; ++request_num) { - const OperationRequest& request = - session_computation_.requests().at(request_num); - if (request.request().op_case() == OpRequest::kParameterRequest) { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - int64 param_no = parameter_request.parameter(); - // Parameters may be out of order so expand ProgramShape parameters - // until it is at least large enough to hold the current parameter - // number. - while (program_shape->parameters_size() <= param_no) { - program_shape->add_parameters(); - program_shape->add_parameter_names(); - } - *program_shape->mutable_parameters(param_no) = request.output_shape(); - *program_shape->mutable_parameter_names(param_no) = - parameter_request.name(); - } - } - - // The root determines the output shape.
- TF_ASSIGN_OR_RETURN(const OperationRequest* root_request, - GetRoot(version, session_computation_)); - *program_shape->mutable_result() = root_request->output_shape(); - if (ShapeUtil::IsOpaque(program_shape->result())) { - return Unimplemented("Computation results cannot be opaque"); - } - - program_shape_ = std::move(program_shape); - program_shape_version_ = version; - } - - return program_shape_; -} - -namespace { - -// A visitor which checks whether an operation is pure functional meaning that -// it doesn't depend on any parameter with an index higher then num_parameters. -// The visitor walks the computation starting at a given operation and sets -// is_functional to false iff a parameter or RNG operation is encountered. -void PureFunctionalVisitor(const SessionComputation& session_computation, - const ComputationDataHandle& handle, - int64 num_parameters, std::set* visited, - bool* is_functional) { - if (visited->count(handle.handle()) != 0 || !*is_functional) { - return; - } - - const OperationRequest& request = - session_computation.requests().at(handle.handle()); - switch (request.request().op_case()) { - case OpRequest::kRngRequest: - *is_functional = false; - break; - - case OpRequest::kConstantRequest: - break; - - case OpRequest::kGetTupleElementRequest: { - const GetTupleElementRequest& get_tuple_element_request = - request.request().get_tuple_element_request(); - PureFunctionalVisitor(session_computation, - get_tuple_element_request.operand(), num_parameters, - visited, is_functional); - break; - } - - case OpRequest::kSliceRequest: { - const SliceRequest& slice_request = request.request().slice_request(); - PureFunctionalVisitor(session_computation, slice_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kDynamicSliceRequest: { - const DynamicSliceRequest& dynamic_slice_request = - request.request().dynamic_slice_request(); - PureFunctionalVisitor(session_computation, - 
dynamic_slice_request.operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - dynamic_slice_request.start_indices(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kDynamicUpdateSliceRequest: { - const DynamicUpdateSliceRequest& dynamic_update_slice_request = - request.request().dynamic_update_slice_request(); - PureFunctionalVisitor(session_computation, - dynamic_update_slice_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - dynamic_update_slice_request.update(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - dynamic_update_slice_request.start_indices(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kConcatenateRequest: { - const ConcatenateRequest& concatenate_request = - request.request().concatenate_request(); - for (const ComputationDataHandle& handle : - concatenate_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - break; - } - - case OpRequest::kConvolveRequest: { - const ConvolveRequest& convolve_request = - request.request().convolve_request(); - PureFunctionalVisitor(session_computation, convolve_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, convolve_request.rhs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kFftRequest: { - const FftRequest& fft_request = request.request().fft_request(); - PureFunctionalVisitor(session_computation, fft_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kCrossReplicaSumRequest: { - // TODO(b/33009255): Implmement constant folding for cross replica sum. 
- *is_functional = false; - break; - } - - case OpRequest::kInfeedRequest: { - *is_functional = false; - break; - } - - case OpRequest::kOutfeedRequest: { - *is_functional = false; - break; - } - - case OpRequest::kHostComputeRequest: { - *is_functional = false; - break; - } - - case OpRequest::kCallRequest: { - const CallRequest& call_request = request.request().call_request(); - for (const ComputationDataHandle& handle : call_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - // TODO(b/32495713): We aren't checking the to_apply computation itself, - // so we conservatively say that computations containing the Call op - // cannot be constant. We cannot set is_functional=false in other similar - // cases since we're already relying on IsConstant to return true. - *is_functional = false; - break; - } - - case OpRequest::kCustomCallRequest: { - *is_functional = false; - break; - } - - case OpRequest::kDotRequest: { - const DotRequest& dot_request = request.request().dot_request(); - PureFunctionalVisitor(session_computation, dot_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, dot_request.rhs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kSendRequest: { - *is_functional = false; - break; - } - - case OpRequest::kRecvRequest: { - *is_functional = false; - break; - } - - case OpRequest::kMapRequest: { - const MapRequest& map_request = request.request().map_request(); - for (const ComputationDataHandle& handle : map_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - // TODO(b/32495713): We aren't checking the to_apply computation itself. 
- break; - } - - case OpRequest::kReduceRequest: { - const ReduceRequest& reduce_request = request.request().reduce_request(); - PureFunctionalVisitor(session_computation, reduce_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, reduce_request.init_value(), - num_parameters, visited, is_functional); - // TODO(b/32495713): We aren't checking the to_apply computation itself. - break; - } - - case OpRequest::kReduceWindowRequest: { - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - PureFunctionalVisitor(session_computation, - reduce_window_request.operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - reduce_window_request.init_value(), num_parameters, - visited, is_functional); - // TODO(b/32495713): We aren't checking the to_apply computation itself. - break; - } - - case OpRequest::kSelectAndScatterRequest: { - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - PureFunctionalVisitor(session_computation, - select_and_scatter_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - select_and_scatter_request.source(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - select_and_scatter_request.init_value(), - num_parameters, visited, is_functional); - // TODO(b/32495713): We aren't checking the select and scatter - // computations themselves. 
- break; - } - - case OpRequest::kBroadcastRequest: { - const BroadcastRequest& broadcast_request = - request.request().broadcast_request(); - PureFunctionalVisitor(session_computation, broadcast_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kReshapeRequest: { - const ReshapeRequest& reshape_request = - request.request().reshape_request(); - PureFunctionalVisitor(session_computation, reshape_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kReverseRequest: { - const ReverseRequest& reverse_request = - request.request().reverse_request(); - PureFunctionalVisitor(session_computation, reverse_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kPadRequest: { - const PadRequest& pad_request = request.request().pad_request(); - PureFunctionalVisitor(session_computation, pad_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, pad_request.padding_value(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kParameterRequest: { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - if (parameter_request.parameter() >= num_parameters) { - *is_functional = false; - } - break; - } - - case OpRequest::kConvertRequest: { - const ConvertRequest& convert_request = - request.request().convert_request(); - PureFunctionalVisitor(session_computation, convert_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBitcastConvertRequest: { - const ConvertRequest& convert_request = - request.request().bitcast_convert_request(); - PureFunctionalVisitor(session_computation, convert_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kWhileRequest: { - const WhileRequest& while_request = request.request().while_request(); - 
PureFunctionalVisitor(session_computation, while_request.init(), - num_parameters, visited, is_functional); - // TODO(b/32495713): We aren't checking the condition and body - // computations themselves. - *is_functional = false; - break; - } - - case OpRequest::kConditionalRequest: { - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - PureFunctionalVisitor(session_computation, - conditional_request.predicate(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - conditional_request.true_operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - conditional_request.false_operand(), num_parameters, - visited, is_functional); - // TODO(b/32495713): We aren't checking the true and false computations - // themselves. - break; - } - - case OpRequest::kTernaryOpRequest: { - const TernaryOpRequest& ternary_op_request = - request.request().ternary_op_request(); - PureFunctionalVisitor(session_computation, ternary_op_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, ternary_op_request.rhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, ternary_op_request.ehs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kTransposeRequest: { - const TransposeRequest& transpose_request = - request.request().transpose_request(); - PureFunctionalVisitor(session_computation, transpose_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kVariadicOpRequest: { - const VariadicOpRequest& variadic_op_request = - request.request().variadic_op_request(); - for (const ComputationDataHandle& handle : - variadic_op_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - break; - } - - case OpRequest::kUnaryOpRequest: { - const UnaryOpRequest& 
unary_op_request = - request.request().unary_op_request(); - PureFunctionalVisitor(session_computation, unary_op_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBatchNormTrainingRequest: { - const BatchNormTrainingRequest& batch_norm_training_request = - request.request().batch_norm_training_request(); - PureFunctionalVisitor(session_computation, - batch_norm_training_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_training_request.scale(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_training_request.offset(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBatchNormInferenceRequest: { - const BatchNormInferenceRequest& batch_norm_inference_request = - request.request().batch_norm_inference_request(); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.scale(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.offset(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.mean(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.variance(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBatchNormGradRequest: { - const BatchNormGradRequest& batch_norm_grad_request = - request.request().batch_norm_grad_request(); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.scale(), num_parameters, - visited, is_functional); - 
PureFunctionalVisitor(session_computation, batch_norm_grad_request.mean(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.variance(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.grad_output(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBinaryOpRequest: { - const BinaryOpRequest& binary_op_request = - request.request().binary_op_request(); - PureFunctionalVisitor(session_computation, binary_op_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, binary_op_request.rhs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kGatherRequest: { - PureFunctionalVisitor(session_computation, - request.request().gather_request().input(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - request.request().gather_request().gather_indices(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::OP_NOT_SET: - LOG(FATAL) << "OperationRequest doesn't contain a request"; - - default: - LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); - } - if (!*is_functional) { - VLOG(1) << "Non-functional: " << request.request().DebugString(); - } - visited->insert(handle.handle()); -} - -} // namespace - -StatusOr UserComputation::IsConstant(const ComputationDataHandle& handle, - int64 num_parameters) { - tensorflow::mutex_lock lock(mutex_); - - // Verify that the handle is valid. 
- auto operation_status = LookUpRequest(handle); - if (!operation_status.ok()) { - return operation_status.status(); - } - - bool is_constant = true; - std::set visited; - PureFunctionalVisitor(session_computation_, handle, num_parameters, &visited, - &is_constant); - - return is_constant; -} - -std::vector -UserComputation::GetEmbeddedComputations( - VersionedComputationHandle::Version version) const { - tensorflow::mutex_lock lock(mutex_); - - VLOG(1) - << "GetEmbeddedComputations(" << name() << " " - << VersionedComputationHandle{session_computation_.computation_handle(), - version} - << ")"; - XLA_VLOG_LINES(3, session_computation_.DebugString()); - - std::vector computations; - std::vector sorted_handles; - for (const auto& handle_request : session_computation_.requests()) { - sorted_handles.push_back(handle_request.first); - } - std::sort(sorted_handles.begin(), sorted_handles.end()); - for (int64 handle : sorted_handles) { - const auto& handle_request = session_computation_.requests().find(handle); - CHECK(handle_request != session_computation_.requests().end()); - int64 handle_value = handle_request->first; - if (handle_value <= version) { - const OperationRequest& request = handle_request->second; - switch (request.request().op_case()) { - case OpRequest::kCallRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const CallRequest& call_request = request.request().call_request(); - const VersionedComputationHandle versioned_handle = { - call_request.to_apply(), - request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kMapRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const MapRequest& map_request = request.request().map_request(); - const VersionedComputationHandle versioned_handle = { - map_request.to_apply(), request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kReduceRequest: { 
- CHECK_EQ(1, request.embedded_computation_versions_size()); - const ReduceRequest& reduce_request = - request.request().reduce_request(); - const VersionedComputationHandle versioned_handle = { - reduce_request.to_apply(), - request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kReduceWindowRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - const VersionedComputationHandle versioned_handle = { - reduce_window_request.to_apply(), - request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kSelectAndScatterRequest: { - CHECK_EQ(2, request.embedded_computation_versions_size()); - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - const VersionedComputationHandle select_versioned_handle = { - select_and_scatter_request.select(), - request.embedded_computation_versions(0)}; - computations.push_back(select_versioned_handle); - const VersionedComputationHandle scatter_versioned_handle = { - select_and_scatter_request.scatter(), - request.embedded_computation_versions(1)}; - computations.push_back(scatter_versioned_handle); - break; - } - - case OpRequest::kWhileRequest: { - CHECK_EQ(2, request.embedded_computation_versions_size()); - const WhileRequest& while_request = request.request().while_request(); - const VersionedComputationHandle condition_versioned_handle = { - while_request.condition(), - request.embedded_computation_versions(0)}; - computations.push_back(condition_versioned_handle); - const VersionedComputationHandle body_versioned_handle = { - while_request.body(), request.embedded_computation_versions(1)}; - computations.push_back(body_versioned_handle); - break; - } - - case OpRequest::kConditionalRequest: { - CHECK_EQ(2, 
request.embedded_computation_versions_size()); - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - const VersionedComputationHandle true_computation_versioned_handle = { - conditional_request.true_computation(), - request.embedded_computation_versions(0)}; - computations.push_back(true_computation_versioned_handle); - const VersionedComputationHandle false_computation_versioned_handle = - {conditional_request.false_computation(), - request.embedded_computation_versions(1)}; - computations.push_back(false_computation_versioned_handle); - break; - } - - default: - // No embedded computation. - break; - } - } - } - VLOG(2) << "Embedded computations: " - << tensorflow::str_util::Join( - computations, ", ", - [](string* out, const VersionedComputationHandle& h) { - out->append(h.ToString()); - }); - return computations; -} - -StatusOr -UserComputation::LookUpRequestForErrorReporting( - const ComputationDataHandle& handle) const { - tensorflow::mutex_lock lock(mutex_); - return LookUpRequest(handle); -} - -tensorflow::gtl::optional UserComputation::ParameterMetadata( - int parameter_number) const { - tensorflow::mutex_lock lock(mutex_); - auto it = parameters_.find(parameter_number); - if (it == parameters_.end()) { - return tensorflow::gtl::nullopt; - } - OperationRequest* op = it->second; - return &op->request().metadata(); -} - -Status UserComputation::RemapEmbeddedComputations( - const std::map& old_to_new) { - auto update = [&old_to_new](ComputationHandle* to_update) -> Status { - int64 old = to_update->handle(); - auto it = old_to_new.find(old); - if (it == old_to_new.end()) { - string mapping = tensorflow::str_util::Join( - old_to_new, ", ", - [](string* out, std::pair element) { - tensorflow::strings::Appendf(out, "%lld:%lld", element.first, - element.second.handle()); - }); - return NotFound( - "could not find referenced (old) computation handle in mapping: " - "%lld; mapping: {%s}", - old, mapping.c_str()); - } - 
VLOG(2) << "remapping " << old << " to " << it->second.handle(); - *to_update = it->second; - return Status::OK(); - }; - TF_RETURN_IF_ERROR(update(session_computation_.mutable_computation_handle())); - for (auto& handle_request : *session_computation_.mutable_requests()) { - OperationRequest& request = handle_request.second; - switch (request.request().op_case()) { - case OpRequest::kCallRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - CallRequest* call_request = - request.mutable_request()->mutable_call_request(); - TF_RETURN_IF_ERROR(update(call_request->mutable_to_apply())); - break; - } - case OpRequest::kMapRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - MapRequest* map_request = - request.mutable_request()->mutable_map_request(); - TF_RETURN_IF_ERROR(update(map_request->mutable_to_apply())); - break; - } - case OpRequest::kReduceRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - ReduceRequest* reduce_request = - request.mutable_request()->mutable_reduce_request(); - TF_RETURN_IF_ERROR(update(reduce_request->mutable_to_apply())); - break; - } - case OpRequest::kReduceWindowRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - ReduceWindowRequest* reduce_window_request = - request.mutable_request()->mutable_reduce_window_request(); - TF_RETURN_IF_ERROR(update(reduce_window_request->mutable_to_apply())); - break; - } - case OpRequest::kSelectAndScatterRequest: { - TF_RET_CHECK(2 == request.embedded_computation_versions_size()); - SelectAndScatterRequest* select_and_scatter_request = - request.mutable_request()->mutable_select_and_scatter_request(); - TF_RETURN_IF_ERROR( - update(select_and_scatter_request->mutable_select())); - TF_RETURN_IF_ERROR( - update(select_and_scatter_request->mutable_scatter())); - break; - } - case OpRequest::kWhileRequest: { - TF_RET_CHECK(2 == request.embedded_computation_versions_size()); - WhileRequest* 
while_request = - request.mutable_request()->mutable_while_request(); - TF_RETURN_IF_ERROR(update(while_request->mutable_condition())); - TF_RETURN_IF_ERROR(update(while_request->mutable_body())); - break; - } - case OpRequest::kConditionalRequest: { - TF_RET_CHECK(2 == request.embedded_computation_versions_size()); - ConditionalRequest* conditional_request = - request.mutable_request()->mutable_conditional_request(); - TF_RETURN_IF_ERROR( - update(conditional_request->mutable_true_computation())); - TF_RETURN_IF_ERROR( - update(conditional_request->mutable_false_computation())); - break; - } - default: - // No embedded computation. - TF_RET_CHECK(0 == request.embedded_computation_versions_size()); - break; - } - } - return Status::OK(); -} - -SessionComputation UserComputation::CloneSessionComputation( - VersionedComputationHandle::Version version) const { - tensorflow::mutex_lock lock(mutex_); - SessionComputation result = session_computation_; - // Erase all the requests that exceed the version specified. - // There's no lower_bound method on tensorflow::protobuf::Map so we iterate - // all the elements. - auto it = result.mutable_requests()->begin(); - while (it != result.mutable_requests()->end()) { - if (it->first > version) { - it = result.mutable_requests()->erase(it); - } else { - ++it; - } - } - return result; -} - -StatusOr UserComputation::LookUpRequest( - const ComputationDataHandle& handle) const { - int64 handle_value = handle.handle(); - if (session_computation_.requests().count(handle_value) == 0) { - return InvalidArgument("no ComputationDataHandle value %lld", handle_value); - } - return &session_computation_.requests().at(handle_value); -} - -Status UserComputation::CheckParametersAreContiguous( - VersionedComputationHandle::Version version) const { - TF_RET_CHECK(version > 0 && version < next_handle_value_); - - // Determine number of parameter inputs at the given version. 
- std::map parameter_requests; - for (int64 request_num = 1; request_num <= version; ++request_num) { - const OperationRequest& request = - session_computation_.requests().at(request_num); - - if (request.request().op_case() == OpRequest::kParameterRequest) { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - // Duplicate parameters should be checked when parameter requests are - // added. - TF_RET_CHECK(0 == - parameter_requests.count(parameter_request.parameter())); - parameter_requests[parameter_request.parameter()] = &parameter_request; - } - } - - for (int64 i = 0; i < parameter_requests.size(); ++i) { - auto it = parameter_requests.find(i); - if (it == parameter_requests.end()) { - return FailedPrecondition( - "computation %s does not have all its parameters populated " - "sequentially, missing parameter %lld", - name_.c_str(), i); - } - } - - return Status::OK(); -} - -namespace { - -// Helper class which builds an HLO computation from a SessionComputation. To -// construct the HLO computation, the SessionComputation graph is walked in -// DFS order lowering each OperationRequest to an HLO instruction.
-class ComputationLowerer { - public: - static StatusOr> Lower( - const string& computation_name, - const SessionComputation& session_computation, - VersionedComputationHandle::Version version, - UserComputation::HloComputationResolver hlo_resolver, - const DebugOptions& debug_options, - bool include_unreachable_instructions) { - ComputationLowerer lowerer(computation_name, session_computation, version, - std::move(hlo_resolver), debug_options, - include_unreachable_instructions); - return lowerer.Lower(); - } - - private: - ComputationLowerer(const string& computation_name, - const SessionComputation& session_computation, - VersionedComputationHandle::Version version, - UserComputation::HloComputationResolver hlo_resolver, - const DebugOptions& debug_options, - bool include_unreachable_instructions) - : hlo_builder_(computation_name), - session_computation_(session_computation), - version_(version), - hlo_resolver_(std::move(hlo_resolver)), - debug_options_(debug_options), - include_unreachable_instructions_(include_unreachable_instructions) {} - - // Build an HLO computation from the SessionComputation at the given - // version. - StatusOr> Lower(); - - private: - // Traverses the computation 'root' using a DFS, calling 'visit' in postorder. - void TraversePostorder( - const ComputationDataHandle& root, - std::unordered_map* visited, - const std::function& visit); - - // DFS visitor of the UserComputation operations which lowers the operations - // to HLO instructions. - void Visit(const ComputationDataHandle& handle, - std::unordered_map* instructions); - - // Resolves a ComputationHandle and Version to a previously lowered - // HloComputation using the hlo_resolver_ function. 
- HloComputation* ResolveComputation( - const ComputationHandle& handle, - VersionedComputationHandle::Version version); - - // This function takes an input value which is being implicitly broadcast into - // an output shape and figures out the right kBroadcast instruction(s) - // necessary to replicate the implicit broadcast semantics explicitly. - HloInstruction* ImplicitBroadcastToExplicitBroadcast( - HloInstruction* operand, const Shape& output_shape); - - HloComputation::Builder hlo_builder_; - const SessionComputation& session_computation_; - const VersionedComputationHandle::Version version_; - const UserComputation::HloComputationResolver hlo_resolver_; - const DebugOptions& debug_options_; - const bool include_unreachable_instructions_; -}; - -// Calls 'apply' on each operand of 'request'. -static void ForEachOperand( - const OperationRequest& request, - const std::function& apply) { - switch (request.request().op_case()) { - case OpRequest::kRngRequest: { - const RngRequest& rng_request = request.request().rng_request(); - for (const ComputationDataHandle& param : rng_request.parameter()) { - apply(param); - } - break; - } - - case OpRequest::kConstantRequest: - break; - case OpRequest::kGetTupleElementRequest: { - const GetTupleElementRequest& get_tuple_element_request = - request.request().get_tuple_element_request(); - apply(get_tuple_element_request.operand()); - break; - } - - case OpRequest::kSliceRequest: { - const SliceRequest& slice_request = request.request().slice_request(); - apply(slice_request.operand()); - break; - } - - case OpRequest::kDynamicSliceRequest: { - const DynamicSliceRequest& dynamic_slice_request = - request.request().dynamic_slice_request(); - apply(dynamic_slice_request.operand()); - apply(dynamic_slice_request.start_indices()); - break; - } - - case OpRequest::kDynamicUpdateSliceRequest: { - const DynamicUpdateSliceRequest& dynamic_update_slice_request = - request.request().dynamic_update_slice_request(); - 
apply(dynamic_update_slice_request.operand()); - apply(dynamic_update_slice_request.update()); - apply(dynamic_update_slice_request.start_indices()); - break; - } - - case OpRequest::kConcatenateRequest: { - const ConcatenateRequest& concatenate_request = - request.request().concatenate_request(); - for (const ComputationDataHandle& handle : - concatenate_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kConvolveRequest: { - const ConvolveRequest& convolve_request = - request.request().convolve_request(); - apply(convolve_request.lhs()); - apply(convolve_request.rhs()); - break; - } - - case OpRequest::kFftRequest: { - const FftRequest& fft_request = request.request().fft_request(); - apply(fft_request.operand()); - break; - } - - case OpRequest::kBatchNormTrainingRequest: { - const BatchNormTrainingRequest& batch_norm_training_request = - request.request().batch_norm_training_request(); - - apply(batch_norm_training_request.operand()); - apply(batch_norm_training_request.scale()); - apply(batch_norm_training_request.offset()); - break; - } - - case OpRequest::kBatchNormInferenceRequest: { - const BatchNormInferenceRequest& batch_norm_inference_request = - request.request().batch_norm_inference_request(); - - apply(batch_norm_inference_request.operand()); - apply(batch_norm_inference_request.scale()); - apply(batch_norm_inference_request.offset()); - apply(batch_norm_inference_request.mean()); - apply(batch_norm_inference_request.variance()); - break; - } - - case OpRequest::kBatchNormGradRequest: { - const BatchNormGradRequest& batch_norm_grad_request = - request.request().batch_norm_grad_request(); - - apply(batch_norm_grad_request.operand()); - apply(batch_norm_grad_request.scale()); - apply(batch_norm_grad_request.mean()); - apply(batch_norm_grad_request.variance()); - apply(batch_norm_grad_request.grad_output()); - break; - } - - case OpRequest::kCrossReplicaSumRequest: { - const CrossReplicaSumRequest& cross_replica_sum_request = - 
request.request().cross_replica_sum_request(); - apply(cross_replica_sum_request.operand()); - break; - } - - case OpRequest::kInfeedRequest: - break; - - case OpRequest::kOutfeedRequest: { - const OutfeedRequest& outfeed_request = - request.request().outfeed_request(); - apply(outfeed_request.operand()); - break; - } - - case OpRequest::kMapRequest: { - const MapRequest& map_request = request.request().map_request(); - for (const ComputationDataHandle& handle : map_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kReduceRequest: { - const ReduceRequest& reduce_request = request.request().reduce_request(); - apply(reduce_request.operand()); - apply(reduce_request.init_value()); - break; - } - - case OpRequest::kReduceWindowRequest: { - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - apply(reduce_window_request.operand()); - apply(reduce_window_request.init_value()); - break; - } - - case OpRequest::kSelectAndScatterRequest: { - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - apply(select_and_scatter_request.operand()); - apply(select_and_scatter_request.source()); - apply(select_and_scatter_request.init_value()); - - break; - } - - case OpRequest::kBroadcastRequest: { - const BroadcastRequest& broadcast_request = - request.request().broadcast_request(); - apply(broadcast_request.operand()); - break; - } - - case OpRequest::kReshapeRequest: { - const ReshapeRequest& reshape_request = - request.request().reshape_request(); - apply(reshape_request.operand()); - break; - } - - case OpRequest::kTransposeRequest: { - const TransposeRequest& transpose_request = - request.request().transpose_request(); - apply(transpose_request.operand()); - break; - } - - case OpRequest::kReverseRequest: { - const ReverseRequest& reverse_request = - request.request().reverse_request(); - apply(reverse_request.operand()); - break; - } - - case 
OpRequest::kPadRequest: { - const PadRequest& pad_request = request.request().pad_request(); - apply(pad_request.operand()); - apply(pad_request.padding_value()); - break; - } - - case OpRequest::kRecvRequest: - case OpRequest::kParameterRequest: - break; - - case OpRequest::kConvertRequest: { - const ConvertRequest& convert_request = - request.request().convert_request(); - apply(convert_request.operand()); - break; - } - - case OpRequest::kBitcastConvertRequest: { - const ConvertRequest& convert_request = - request.request().bitcast_convert_request(); - apply(convert_request.operand()); - break; - } - - case OpRequest::kWhileRequest: { - const WhileRequest& while_request = request.request().while_request(); - apply(while_request.init()); - break; - } - - case OpRequest::kConditionalRequest: { - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - apply(conditional_request.predicate()); - apply(conditional_request.true_operand()); - apply(conditional_request.false_operand()); - break; - } - - case OpRequest::kTernaryOpRequest: { - const TernaryOpRequest& ternary_op_request = - request.request().ternary_op_request(); - apply(ternary_op_request.lhs()); - apply(ternary_op_request.rhs()); - apply(ternary_op_request.ehs()); - break; - } - - case OpRequest::kVariadicOpRequest: { - const VariadicOpRequest& variadic_op_request = - request.request().variadic_op_request(); - for (const ComputationDataHandle& handle : - variadic_op_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kCallRequest: { - const CallRequest& call_request = request.request().call_request(); - for (const ComputationDataHandle& handle : call_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kCustomCallRequest: { - const CustomCallRequest& cc_request = - request.request().custom_call_request(); - for (const ComputationDataHandle& operand : cc_request.operands()) { - apply(operand); - } - break; - } - - case 
OpRequest::kHostComputeRequest: { - const HostComputeRequest& hc_request = - request.request().host_compute_request(); - for (const ComputationDataHandle& operand : hc_request.operands()) { - apply(operand); - } - break; - } - - case OpRequest::kDotRequest: { - const DotRequest& dot_request = request.request().dot_request(); - apply(dot_request.rhs()); - apply(dot_request.lhs()); - break; - } - - case OpRequest::kUnaryOpRequest: { - const UnaryOpRequest& unary_op_request = - request.request().unary_op_request(); - apply(unary_op_request.operand()); - break; - } - - case OpRequest::kBinaryOpRequest: { - const BinaryOpRequest& binary_op_request = - request.request().binary_op_request(); - apply(binary_op_request.rhs()); - apply(binary_op_request.lhs()); - break; - } - - case OpRequest::kReducePrecisionRequest: { - const ReducePrecisionRequest& reduce_precision_request = - request.request().reduce_precision_request(); - apply(reduce_precision_request.operand()); - break; - } - - case OpRequest::kTraceRequest: { - const TraceRequest& trace_request = request.request().trace_request(); - apply(trace_request.operand()); - break; - } - - case OpRequest::kSendRequest: { - const SendRequest& send_request = request.request().send_request(); - apply(send_request.operand()); - break; - } - - case OpRequest::kGatherRequest: { - const GatherRequest& gather_request = request.request().gather_request(); - apply(gather_request.input()); - apply(gather_request.gather_indices()); - break; - } - - case OpRequest::OP_NOT_SET: - LOG(FATAL) << "OperationRequest doesn't contain a request"; - - default: - LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); - } -} - -void ComputationLowerer::TraversePostorder( - const ComputationDataHandle& root, - std::unordered_map* visited, - const std::function& visit) { - // Stack containing {handle, enter} pairs. The 'enter' value describes whether - // we are entering or leaving 'handle'. 
- std::stack> work; - work.push({root, true}); - while (!work.empty()) { - ComputationDataHandle handle; - bool enter; - std::tie(handle, enter) = work.top(); - work.pop(); - - if (enter) { - // We are entering 'handle'. The first time we enter 'handle', we add it - // to 'visited' with a nullptr value. If 'handle' is already in 'visited', - // we do not visit it again. This algorithm only uses the presence of - // a handle in 'visited', but we use a map so we can use the same data - // structure to store the HloInstruction outputs. - if (visited->emplace(handle.handle(), nullptr).second) { - const OperationRequest& request = - session_computation_.requests().at(handle.handle()); - // Push the corresponding 'leave' action onto the stack, followed by - // the operands. - work.push({handle, false}); - ForEachOperand(request, [&work](const ComputationDataHandle& child) { - work.push({child, true}); - }); - } - } else { - // We are leaving 'handle'. We have visited the operands of 'handle', and - // now can visit the 'handle' itself. - visit(handle); - } - } -} - -StatusOr> ComputationLowerer::Lower() { - // Map from ComputationDataHandle to HLO instruction. Serves as a record of - // which operations have been visited as well as a cache for looking up - // ComputationDataHandles as HloInstructions. - std::unordered_map instructions; - - TF_ASSIGN_OR_RETURN(const OperationRequest* root_request, - GetRoot(version_, session_computation_)); - - auto visit = [&](const ComputationDataHandle& handle) { - Visit(handle, &instructions); - }; - TraversePostorder(root_request->output_handle(), &instructions, visit); - HloInstruction* hlo_root = - instructions.at(root_request->output_handle().handle()); - - if (include_unreachable_instructions_) { - // Iterate through all computation data handles, and visit any unvisited - // operations. 
- for (int64 request_num = 1; request_num <= version_; ++request_num) { - TF_ASSIGN_OR_RETURN(const OperationRequest* request, - LookUpRequest(request_num, session_computation_)); - TraversePostorder(request->output_handle(), &instructions, visit); - } - } - - return hlo_builder_.Build(hlo_root); -} - -HloComputation* ComputationLowerer::ResolveComputation( - const ComputationHandle& handle, - VersionedComputationHandle::Version version) { - const VersionedComputationHandle checked_handle = {handle, version}; - return hlo_resolver_(checked_handle); -} - -HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( - HloInstruction* operand, const Shape& output_shape) { - auto fadd = [this](std::unique_ptr x) { - return hlo_builder_.AddInstruction(std::move(x)); - }; - return fadd( - HloInstruction::CreateBroadcastSequence(output_shape, operand, fadd)); -} - -void ComputationLowerer::Visit( - const ComputationDataHandle& handle, - std::unordered_map* instructions) { - CHECK_LE(handle.handle(), version_); - CHECK(instructions->at(handle.handle()) == nullptr); - const OperationRequest& request = - session_computation_.requests().at(handle.handle()); - auto add_instruction = [&](std::unique_ptr instruction) { - HloInstruction* hlo_instruction = - hlo_builder_.AddInstruction(std::move(instruction)); - hlo_instruction->set_metadata(request.request().metadata()); - if (request.request().has_sharding()) { - OpSharding op_sharding = request.request().sharding(); - hlo_instruction->set_sharding( - HloSharding::FromProto(op_sharding).ValueOrDie()); - } - return hlo_instruction; - }; - auto lookup_instruction = [&](const ComputationDataHandle& handle) { - return instructions->at(handle.handle()); - }; - HloInstruction* hlo_instruction; - switch (request.request().op_case()) { - case OpRequest::kRngRequest: { - const RngRequest& rng_request = request.request().rng_request(); - std::vector parameters; - for (const ComputationDataHandle& param : 
rng_request.parameter()) { - parameters.push_back(lookup_instruction(param)); - } - hlo_instruction = add_instruction(HloInstruction::CreateRng( - request.output_shape(), rng_request.distribution(), parameters)); - break; - } - - case OpRequest::kConstantRequest: { - const ConstantRequest& constant_request = - request.request().constant_request(); - hlo_instruction = add_instruction(HloInstruction::CreateConstant( - Literal::CreateFromProto(constant_request.literal()) - .ConsumeValueOrDie())); - break; - } - - case OpRequest::kGetTupleElementRequest: { - const GetTupleElementRequest& get_tuple_element_request = - request.request().get_tuple_element_request(); - HloInstruction* operand = - lookup_instruction(get_tuple_element_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateGetTupleElement( - request.output_shape(), operand, get_tuple_element_request.index())); - break; - } - - case OpRequest::kSliceRequest: { - const SliceRequest& slice_request = request.request().slice_request(); - HloInstruction* operand = lookup_instruction(slice_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateSlice( - request.output_shape(), operand, - AsInt64Slice(slice_request.start_indices()), - AsInt64Slice(slice_request.limit_indices()), - AsInt64Slice(slice_request.strides()))); - break; - } - - case OpRequest::kDynamicSliceRequest: { - const DynamicSliceRequest& dynamic_slice_request = - request.request().dynamic_slice_request(); - HloInstruction* operand = - lookup_instruction(dynamic_slice_request.operand()); - HloInstruction* start_indices = - lookup_instruction(dynamic_slice_request.start_indices()); - - hlo_instruction = add_instruction(HloInstruction::CreateDynamicSlice( - request.output_shape(), operand, start_indices, - AsInt64Slice(dynamic_slice_request.slice_sizes()))); - break; - } - - case OpRequest::kDynamicUpdateSliceRequest: { - const DynamicUpdateSliceRequest& dynamic_update_slice_request = - 
request.request().dynamic_update_slice_request(); - HloInstruction* operand = - lookup_instruction(dynamic_update_slice_request.operand()); - HloInstruction* update = - lookup_instruction(dynamic_update_slice_request.update()); - HloInstruction* start_indices = - lookup_instruction(dynamic_update_slice_request.start_indices()); - hlo_instruction = - add_instruction(HloInstruction::CreateDynamicUpdateSlice( - request.output_shape(), operand, update, start_indices)); - break; - } - - case OpRequest::kConcatenateRequest: { - const ConcatenateRequest& concatenate_request = - request.request().concatenate_request(); - std::vector operands; - for (const ComputationDataHandle& handle : - concatenate_request.operands()) { - HloInstruction* operand = lookup_instruction(handle); - operands.push_back(operand); - } - hlo_instruction = add_instruction(HloInstruction::CreateConcatenate( - request.output_shape(), operands, concatenate_request.dimension())); - break; - } - - case OpRequest::kConvolveRequest: { - const ConvolveRequest& convolve_request = - request.request().convolve_request(); - HloInstruction* lhs = lookup_instruction(convolve_request.lhs()); - HloInstruction* rhs = lookup_instruction(convolve_request.rhs()); - hlo_instruction = add_instruction(HloInstruction::CreateConvolve( - request.output_shape(), lhs, rhs, convolve_request.window(), - convolve_request.dimension_numbers())); - break; - } - - case OpRequest::kFftRequest: { - const FftRequest& fft_request = request.request().fft_request(); - HloInstruction* operand = lookup_instruction(fft_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateFft( - request.output_shape(), operand, fft_request.fft_type(), - AsInt64Slice(fft_request.fft_length()))); - break; - } - - case OpRequest::kDotRequest: { - const DotRequest& dot_request = request.request().dot_request(); - HloInstruction* lhs = lookup_instruction(dot_request.lhs()); - HloInstruction* rhs = lookup_instruction(dot_request.rhs()); - 
hlo_instruction = add_instruction(HloInstruction::CreateDot( - request.output_shape(), lhs, rhs, dot_request.dimension_numbers())); - break; - } - - case OpRequest::kCrossReplicaSumRequest: { - const CrossReplicaSumRequest& cross_replica_sum_request = - request.request().cross_replica_sum_request(); - HloInstruction* operand = - lookup_instruction(cross_replica_sum_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateCrossReplicaSum( - request.output_shape(), {operand})); - break; - } - - case OpRequest::kInfeedRequest: { - const InfeedRequest& infeed_request = request.request().infeed_request(); - hlo_instruction = add_instruction(HloInstruction::CreateInfeed( - request.output_shape(), infeed_request.config())); - break; - } - - case OpRequest::kOutfeedRequest: { - const OutfeedRequest& outfeed_request = - request.request().outfeed_request(); - HloInstruction* operand = lookup_instruction(outfeed_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateOutfeed( - outfeed_request.shape(), operand, outfeed_request.outfeed_config())); - break; - } - - case OpRequest::kMapRequest: { - const MapRequest& map_request = request.request().map_request(); - std::vector operands; - for (const ComputationDataHandle& handle : map_request.operands()) { - HloInstruction* operand = lookup_instruction(handle); - operands.push_back(operand); - } - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version map_version = - request.embedded_computation_versions(0); - HloComputation* map_computation = - ResolveComputation(map_request.to_apply(), map_version); - hlo_instruction = add_instruction(HloInstruction::CreateMap( - request.output_shape(), operands, map_computation)); - break; - } - - case OpRequest::kReduceRequest: { - const ReduceRequest& reduce_request = request.request().reduce_request(); - HloInstruction* operand = lookup_instruction(reduce_request.operand()); - HloInstruction* init_value = - 
lookup_instruction(reduce_request.init_value()); - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version reduce_version = - request.embedded_computation_versions(0); - HloComputation* reduce_computation = - ResolveComputation(reduce_request.to_apply(), reduce_version); - hlo_instruction = add_instruction(HloInstruction::CreateReduce( - request.output_shape(), operand, init_value, - AsInt64Slice(reduce_request.dimensions()), reduce_computation)); - break; - } - - case OpRequest::kReduceWindowRequest: { - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - HloInstruction* operand = - lookup_instruction(reduce_window_request.operand()); - HloInstruction* init_value = - lookup_instruction(reduce_window_request.init_value()); - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version reduce_window_version = - request.embedded_computation_versions(0); - HloComputation* reduce_window_computation = ResolveComputation( - reduce_window_request.to_apply(), reduce_window_version); - hlo_instruction = add_instruction(HloInstruction::CreateReduceWindow( - request.output_shape(), operand, init_value, - reduce_window_request.window(), reduce_window_computation)); - break; - } - - case OpRequest::kSelectAndScatterRequest: { - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - HloInstruction* operand = - lookup_instruction(select_and_scatter_request.operand()); - HloInstruction* source = - lookup_instruction(select_and_scatter_request.source()); - HloInstruction* init_value = - lookup_instruction(select_and_scatter_request.init_value()); - CHECK_EQ(2, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version select_version = - request.embedded_computation_versions(0); - VersionedComputationHandle::Version scatter_version = - request.embedded_computation_versions(1); - 
HloComputation* select_computation = ResolveComputation( - select_and_scatter_request.select(), select_version); - HloComputation* scatter_computation = ResolveComputation( - select_and_scatter_request.scatter(), scatter_version); - hlo_instruction = add_instruction(HloInstruction::CreateSelectAndScatter( - request.output_shape(), operand, select_computation, - select_and_scatter_request.window(), source, init_value, - scatter_computation)); - break; - } - - case OpRequest::kBatchNormTrainingRequest: { - const BatchNormTrainingRequest& batch_norm_training_request = - request.request().batch_norm_training_request(); - HloInstruction* operand = - lookup_instruction(batch_norm_training_request.operand()); - HloInstruction* scale = - lookup_instruction(batch_norm_training_request.scale()); - HloInstruction* offset = - lookup_instruction(batch_norm_training_request.offset()); - - hlo_instruction = add_instruction(HloInstruction::CreateBatchNormTraining( - request.output_shape(), operand, scale, offset, - batch_norm_training_request.epsilon(), - batch_norm_training_request.feature_index())); - break; - } - - case OpRequest::kBatchNormInferenceRequest: { - const BatchNormInferenceRequest& batch_norm_inference_request = - request.request().batch_norm_inference_request(); - HloInstruction* operand = - lookup_instruction(batch_norm_inference_request.operand()); - HloInstruction* scale = - lookup_instruction(batch_norm_inference_request.scale()); - HloInstruction* offset = - lookup_instruction(batch_norm_inference_request.offset()); - HloInstruction* mean = - lookup_instruction(batch_norm_inference_request.mean()); - HloInstruction* variance = - lookup_instruction(batch_norm_inference_request.variance()); - - hlo_instruction = - add_instruction(HloInstruction::CreateBatchNormInference( - request.output_shape(), operand, scale, offset, mean, variance, - batch_norm_inference_request.epsilon(), - batch_norm_inference_request.feature_index())); - break; - } - - case 
OpRequest::kBatchNormGradRequest: { - const BatchNormGradRequest& batch_norm_grad_request = - request.request().batch_norm_grad_request(); - - HloInstruction* operand = - lookup_instruction(batch_norm_grad_request.operand()); - HloInstruction* scale = - lookup_instruction(batch_norm_grad_request.scale()); - HloInstruction* mean = lookup_instruction(batch_norm_grad_request.mean()); - HloInstruction* variance = - lookup_instruction(batch_norm_grad_request.variance()); - HloInstruction* grad_output = - lookup_instruction(batch_norm_grad_request.grad_output()); - - hlo_instruction = add_instruction(HloInstruction::CreateBatchNormGrad( - request.output_shape(), operand, scale, mean, variance, grad_output, - batch_norm_grad_request.epsilon(), - batch_norm_grad_request.feature_index())); - break; - } - - case OpRequest::kBroadcastRequest: { - const BroadcastRequest& broadcast_request = - request.request().broadcast_request(); - HloInstruction* operand = lookup_instruction(broadcast_request.operand()); - std::vector broadcast_dimensions; - // The client-level broadcast instruction just appends dimensions on the - // left (adds lowest numbered dimensions). The HLO broadcast op is more - // flexible and can add new dimensions anywhere. The broadcast_dimensions - // maps operand dimensions to dimensions in the broadcast output, so - // to append dimensions on the left the broadcast_dimensions should just - // be the n highest dimension numbers of the output shape where n is - // the number of input dimensions. 
- broadcast_dimensions.reserve(ShapeUtil::Rank(operand->shape())); - for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { - broadcast_dimensions.push_back(i + - ShapeUtil::Rank(request.output_shape()) - - ShapeUtil::Rank(operand->shape())); - } - hlo_instruction = add_instruction(HloInstruction::CreateBroadcast( - request.output_shape(), operand, broadcast_dimensions)); - break; - } - - case OpRequest::kReshapeRequest: { - const ReshapeRequest& reshape_request = - request.request().reshape_request(); - HloInstruction* operand = lookup_instruction(reshape_request.operand()); - HloInstruction* transposed; - if (IsIdentityPermutation(AsInt64Slice(reshape_request.dimensions()))) { - transposed = operand; - } else { - transposed = add_instruction(HloInstruction::CreateTranspose( - ShapeUtil::PermuteDimensions( - InversePermutation(AsInt64Slice(reshape_request.dimensions())), - operand->shape()), - operand, AsInt64Slice(reshape_request.dimensions()))); - } - hlo_instruction = add_instruction( - HloInstruction::CreateReshape(request.output_shape(), transposed)); - break; - } - - case OpRequest::kTransposeRequest: { - const TransposeRequest& transpose_request = - request.request().transpose_request(); - HloInstruction* operand = lookup_instruction(transpose_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateTranspose( - ShapeUtil::PermuteDimensions( - InversePermutation(AsInt64Slice(transpose_request.dimensions())), - operand->shape()), - operand, AsInt64Slice(transpose_request.dimensions()))); - break; - } - - case OpRequest::kReverseRequest: { - const ReverseRequest& reverse_request = - request.request().reverse_request(); - HloInstruction* operand = lookup_instruction(reverse_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateReverse( - request.output_shape(), operand, - AsInt64Slice(reverse_request.dimensions()))); - break; - } - - case OpRequest::kPadRequest: { - const PadRequest& pad_request = 
request.request().pad_request(); - HloInstruction* operand = lookup_instruction(pad_request.operand()); - HloInstruction* padding_value = - lookup_instruction(pad_request.padding_value()); - hlo_instruction = add_instruction(HloInstruction::CreatePad( - request.output_shape(), operand, padding_value, - pad_request.padding_config())); - break; - } - - case OpRequest::kRecvRequest: { - const RecvRequest& recv_request = request.request().recv_request(); - HloInstruction* recv = add_instruction(HloInstruction::CreateRecv( - request.output_shape(), recv_request.channel_handle().handle())); - hlo_instruction = add_instruction(HloInstruction::CreateRecvDone(recv)); - break; - } - - case OpRequest::kParameterRequest: { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - hlo_instruction = add_instruction(HloInstruction::CreateParameter( - parameter_request.parameter(), request.output_shape(), - parameter_request.name())); - break; - } - - case OpRequest::kConvertRequest: { - const ConvertRequest& convert_request = - request.request().convert_request(); - HloInstruction* operand = lookup_instruction(convert_request.operand()); - hlo_instruction = add_instruction( - HloInstruction::CreateConvert(request.output_shape(), operand)); - break; - } - - case OpRequest::kBitcastConvertRequest: { - const ConvertRequest& convert_request = - request.request().bitcast_convert_request(); - HloInstruction* operand = lookup_instruction(convert_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateBitcastConvert( - request.output_shape(), operand)); - break; - } - - case OpRequest::kWhileRequest: { - const WhileRequest& while_request = request.request().while_request(); - CHECK_EQ(2, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version condition_version = - request.embedded_computation_versions(0); - HloComputation* condition = - ResolveComputation(while_request.condition(), condition_version); - 
VersionedComputationHandle::Version body_version = - request.embedded_computation_versions(1); - HloComputation* body = - ResolveComputation(while_request.body(), body_version); - HloInstruction* init = lookup_instruction(while_request.init()); - hlo_instruction = add_instruction(HloInstruction::CreateWhile( - request.output_shape(), condition, body, init)); - break; - } - - case OpRequest::kConditionalRequest: { - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - CHECK_EQ(2, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version true_computation_version = - request.embedded_computation_versions(0); - HloComputation* true_computation = ResolveComputation( - conditional_request.true_computation(), true_computation_version); - VersionedComputationHandle::Version false_computation_version = - request.embedded_computation_versions(1); - HloComputation* false_computation = ResolveComputation( - conditional_request.false_computation(), false_computation_version); - HloInstruction* predicate = - lookup_instruction(conditional_request.predicate()); - HloInstruction* true_operand = - lookup_instruction(conditional_request.true_operand()); - HloInstruction* false_operand = - lookup_instruction(conditional_request.false_operand()); - hlo_instruction = add_instruction(HloInstruction::CreateConditional( - request.output_shape(), predicate, true_operand, true_computation, - false_operand, false_computation)); - break; - } - - case OpRequest::kTernaryOpRequest: { - const TernaryOpRequest& ternary_op_request = - request.request().ternary_op_request(); - HloInstruction* lhs = lookup_instruction(ternary_op_request.lhs()); - HloInstruction* rhs = lookup_instruction(ternary_op_request.rhs()); - HloInstruction* ehs = lookup_instruction(ternary_op_request.ehs()); - auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop()); - if (debug_options_.xla_eliminate_hlo_implicit_broadcast() && - 
!ShapeUtil::IsTuple(request.output_shape())) { - if (!ShapeUtil::IsTuple(lhs->shape()) && - !ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { - // lhs side is being implicitly broadcast. Change to explicit. - lhs = - ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape()); - } - - if (!ShapeUtil::IsTuple(rhs->shape()) && - !ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) { - rhs = - ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape()); - } - - if (!ShapeUtil::IsTuple(ehs->shape()) && - !ShapeUtil::SameDimensions(request.output_shape(), ehs->shape())) { - ehs = - ImplicitBroadcastToExplicitBroadcast(ehs, request.output_shape()); - } - } - - hlo_instruction = add_instruction(HloInstruction::CreateTernary( - request.output_shape(), hlo_opcode, lhs, rhs, ehs)); - break; - } - - case OpRequest::kVariadicOpRequest: { - const VariadicOpRequest& variadic_op_request = - request.request().variadic_op_request(); - std::vector operands; - for (const ComputationDataHandle& handle : - variadic_op_request.operands()) { - HloInstruction* operand = lookup_instruction(handle); - operands.push_back(operand); - } - auto hlo_opcode = - VariadicOperationToHloOpcode(variadic_op_request.varop()); - hlo_instruction = add_instruction(HloInstruction::CreateVariadic( - request.output_shape(), hlo_opcode, operands)); - break; - } - - case OpRequest::kCallRequest: { - const CallRequest& call_request = request.request().call_request(); - std::vector operands; - for (const ComputationDataHandle& handle : call_request.operands()) { - operands.push_back(lookup_instruction(handle)); - } - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version call_version = - request.embedded_computation_versions(0); - HloComputation* call_computation = - ResolveComputation(call_request.to_apply(), call_version); - hlo_instruction = add_instruction(HloInstruction::CreateCall( - request.output_shape(), operands, 
call_computation)); - break; - } - - case OpRequest::kCustomCallRequest: { - const CustomCallRequest& cc_request = - request.request().custom_call_request(); - std::vector operands; - for (const ComputationDataHandle& operand : cc_request.operands()) { - operands.push_back(lookup_instruction(operand)); - } - hlo_instruction = add_instruction(HloInstruction::CreateCustomCall( - cc_request.shape(), operands, cc_request.call_target_name())); - break; - } - - case OpRequest::kHostComputeRequest: { - const HostComputeRequest& host_compute_request = - request.request().host_compute_request(); - std::vector operands; - for (const ComputationDataHandle& operand : - host_compute_request.operands()) { - operands.push_back(lookup_instruction(operand)); - } - auto output_shape = host_compute_request.shape(); - auto channel_name = host_compute_request.channel_name(); - auto cost_estimate_ns = host_compute_request.cost_estimate_ns(); - hlo_instruction = add_instruction(HloInstruction::CreateHostCompute( - output_shape, operands, channel_name, cost_estimate_ns)); - break; - } - - case OpRequest::kUnaryOpRequest: { - const UnaryOpRequest& unary_op_request = - request.request().unary_op_request(); - HloInstruction* operand = lookup_instruction(unary_op_request.operand()); - auto hlo_opcode = UnaryOperationToHloOpcode(unary_op_request.unop()); - hlo_instruction = add_instruction(HloInstruction::CreateUnary( - request.output_shape(), hlo_opcode, operand)); - break; - } - - case OpRequest::kBinaryOpRequest: { - const BinaryOpRequest& binary_op_request = - request.request().binary_op_request(); - HloInstruction* lhs = lookup_instruction(binary_op_request.lhs()); - HloInstruction* rhs = lookup_instruction(binary_op_request.rhs()); - auto hlo_opcode = BinaryOperationToHloOpcode(binary_op_request.binop()); - if (binary_op_request.broadcast_dimensions_size() > 0 && - ShapeUtil::Rank(lhs->shape()) != ShapeUtil::Rank(rhs->shape())) { - // Emit a broadcast instruction to perform the 
"broadcast in dimension" - // operation. - HloInstruction* operand_to_broadcast = - ShapeUtil::Rank(lhs->shape()) < ShapeUtil::Rank(rhs->shape()) ? lhs - : rhs; - CHECK_EQ(ShapeUtil::Rank(operand_to_broadcast->shape()), - binary_op_request.broadcast_dimensions().size()); - - // Construct the bounds of the shape of the kBroadcast instruction - // responsible for the in-dimension broadcast. - std::vector output_dimensions; - for (int64 size : request.output_shape().dimensions()) { - output_dimensions.push_back(size); - } - for (int64 operand_dim = 0; - operand_dim < ShapeUtil::Rank(operand_to_broadcast->shape()); - ++operand_dim) { - int64 output_dim = - binary_op_request.broadcast_dimensions()[operand_dim]; - output_dimensions[output_dim] = - operand_to_broadcast->shape().dimensions(operand_dim); - } - - Shape broadcast_shape = ShapeUtil::MakeShape( - operand_to_broadcast->shape().element_type(), output_dimensions); - - // The broadcast semantics of a client-level binary op broadcast is - // identical to the HLO broadcast semantics so the broadcast_dimensions - // field can just be passed to the instruction builder. - HloInstruction* broadcasted_operand = - add_instruction(HloInstruction::CreateBroadcast( - broadcast_shape, operand_to_broadcast, - AsInt64Slice(binary_op_request.broadcast_dimensions()))); - - lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs; - rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs; - } - if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { - if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { - // lhs side is being implicitly broadcast. Change to explicit. 
- lhs = - ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape()); - } - - if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) { - rhs = - ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape()); - } - } - hlo_instruction = add_instruction(HloInstruction::CreateBinary( - request.output_shape(), hlo_opcode, lhs, rhs)); - break; - } - - case OpRequest::kReducePrecisionRequest: { - const ReducePrecisionRequest& reduce_precision_request = - request.request().reduce_precision_request(); - HloInstruction* operand = - lookup_instruction(reduce_precision_request.operand()); - auto exponent_bits = reduce_precision_request.exponent_bits(); - auto mantissa_bits = reduce_precision_request.mantissa_bits(); - hlo_instruction = add_instruction(HloInstruction::CreateReducePrecision( - request.output_shape(), operand, exponent_bits, mantissa_bits)); - break; - } - - case OpRequest::kTraceRequest: { - const TraceRequest& trace_request = request.request().trace_request(); - HloInstruction* operand = lookup_instruction(trace_request.operand()); - hlo_instruction = add_instruction( - HloInstruction::CreateTrace(trace_request.tag(), operand)); - break; - } - - case OpRequest::kSendRequest: { - const SendRequest& send_request = request.request().send_request(); - HloInstruction* operand = lookup_instruction(send_request.operand()); - HloInstruction* send = add_instruction(HloInstruction::CreateSend( - operand, send_request.channel_handle().handle())); - hlo_instruction = add_instruction(HloInstruction::CreateSendDone(send)); - break; - } - - case OpRequest::kGatherRequest: { - const GatherRequest& gather_request = request.request().gather_request(); - HloInstruction* input_operand = - lookup_instruction(gather_request.input()); - HloInstruction* gather_indices_operand = - lookup_instruction(gather_request.gather_indices()); - std::vector window_bounds; - c_copy(gather_request.window_bounds(), std::back_inserter(window_bounds)); - hlo_instruction = 
add_instruction(HloInstruction::CreateGather( - request.output_shape(), input_operand, gather_indices_operand, - gather_request.dimension_numbers(), window_bounds)); - break; - } - - case OpRequest::OP_NOT_SET: - LOG(FATAL) << "OperationRequest doesn't contain a request"; - - default: - LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); - } - (*instructions)[handle.handle()] = hlo_instruction; -} // NOLINT(readability/fn_size) - -} // namespace - -StatusOr> UserComputation::BuildHloComputation( - VersionedComputationHandle::Version version, - HloComputationResolver hlo_resolver, const DebugOptions& debug_options, - bool include_unreachable_instructions) const { - tensorflow::mutex_lock lock(mutex_); - - VLOG(2) << "Building HloComputation from UserComputation " << name_ - << " at version " << version; - XLA_VLOG_LINES(3, session_computation_.DebugString()); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_computation, - ComputationLowerer::Lower( - tensorflow::strings::StrCat(name(), ".v", version), - session_computation_, version, std::move(hlo_resolver), debug_options, - include_unreachable_instructions)); - - return std::move(hlo_computation); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h deleted file mode 100644 index 5544c868fe905c1ca7e6cab32738440add2e3b4f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/user_computation.h +++ /dev/null @@ -1,413 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// A UserComputation is the built-up computation that users create via the -// XLA Service interface. -// -// The XLA service adds instructions to a user computation via this -// interface. The state of the computation is stored as a SessionComputation -// proto which holds a record of all operation-building requests received by the -// XLA service. -// -// UserComputations are lowered to HloComputations which are passed to the high -// level compiler interface. -class UserComputation { - public: - // Factory used when restoring a computation from serialized session - // computation (computation snapshot) data. Remaps any references to - // computation handle via the old_to_new mapping. 
- // - // An error will occur if the old_to_new mapping cannot resolve a reference to - // a computation that is present in session_computation. - static StatusOr> MakeWithRemapping( - const SessionComputation& session_computation, - const ComputationHandle& handle, - const std::map& old_to_new); - - // Creates an empty computation with the given name and computation handle. - explicit UserComputation(const string& name, const ComputationHandle& handle); - - // Enqueues a parameter-retrieving instruction onto this user computation. - // Returns an error status if the parameter number is already registered with - // different values. - StatusOr AddParameterInstruction( - const ParameterRequest& parameter_request); - - // Enqueues a pad instruction onto this user computation. - StatusOr AddPadInstruction( - const PadRequest& pad_request); - - // Enqueues a tracing instruction onto this user computation. - // Returns an error status if the operand cannot be resolved. - Status AddTraceInstruction(const TraceRequest& trace_request); - - // Enqueues a random number generation instruction onto this user computation. - StatusOr AddRngInstruction( - const RngRequest& rng_request); - - // Enqueues a unary instruction onto this user computation. - // Returns an error status if the operand index is out of bounds. - StatusOr AddUnaryInstruction( - const UnaryOpRequest& unary_request); - - // Enqueues a batch norm training instruction onto this user computation. - StatusOr AddBatchNormTrainingInstruction( - const BatchNormTrainingRequest& batch_norm_training_request); - - // Enqueues a batch norm inference instruction onto this user computation. - StatusOr AddBatchNormInferenceInstruction( - const BatchNormInferenceRequest& batch_norm_inference_request); - - // Enqueues a batch norm grad instruction onto this user computation. 
- StatusOr AddBatchNormGradInstruction( - const BatchNormGradRequest& batch_norm_grad_request); - - // Enqueues a binary instruction onto this user computation. - // Returns an error status if the operand indices are out of bounds. - StatusOr AddBinaryInstruction( - const BinaryOpRequest& binary_request); - - // Enqueues a ternary instruction onto this user computation. - // Returns an error status if the operand indices are out of bounds. - StatusOr AddTernaryInstruction( - const TernaryOpRequest& ternary_request); - - // Enqueues a variadic instruction onto this user computation. - // Returns an error status if the operand indices are out of bounds. - StatusOr AddVariadicInstruction( - const VariadicOpRequest& variadic_request); - - // Enqueues a constant instruction onto this user computation. - StatusOr AddConstantInstruction( - const ConstantRequest& constant_request); - - // Enqueues a get tuple element instruction onto this user computation. - StatusOr AddGetTupleElementInstruction( - const GetTupleElementRequest& get_tuple_element_request); - - // Enqueues a map instruction onto this user computation. - StatusOr AddMapInstruction( - const MapRequest& map_request, - const UserComputation& to_apply_computation); - - // Enqueues a reduce-precision instruction onto this user computation. - StatusOr AddReducePrecisionInstruction( - const ReducePrecisionRequest& reduce_precision_request); - - // Enqueues a convolution instruction onto this user computation. - StatusOr AddConvolveInstruction( - const ConvolveRequest& convolve_request); - - // Enqueues an FFT instruction onto this user computation. - StatusOr AddFftInstruction( - const FftRequest& fft_request); - - // Enqueues a cross replica sum instruction onto this user computation. - StatusOr AddCrossReplicaSumInstruction( - const CrossReplicaSumRequest& cross_replica_sum_request); - - // Enqueues an infeed instruction onto this user computation. 
- StatusOr AddInfeedInstruction( - const InfeedRequest& infeed_request); - - // Enqueues an outfeed instruction onto this user computation. - StatusOr AddOutfeedInstruction( - const OutfeedRequest& outfeed_request); - - // Enqueues a host compute instruction onto this user computation. - StatusOr AddHostComputeInstruction( - const HostComputeRequest& host_compute_request); - - // Enqueues a call instruction onto this user computation. - StatusOr AddCallInstruction( - const CallRequest& call_request, - const UserComputation& to_apply_computation); - - // Enqueues a custom call instruction onto this user computation. - StatusOr AddCustomCallInstruction( - const CustomCallRequest& custom_call_request); - - // Enqueues a dot instruction onto this user computation. - StatusOr AddDotInstruction( - const DotRequest& dot_request); - - // Enqueues a broadcast instruction onto this user computation. - StatusOr AddBroadcastInstruction( - const BroadcastRequest& broadcast_request); - - // Enqueues a reshape instruction onto this user computation. - StatusOr AddReshapeInstruction( - const ReshapeRequest& reshape_request); - - // Enqueues a transpose instruction onto this user computation. - StatusOr AddTransposeInstruction( - const TransposeRequest& transpose_request); - - // Enqueues a slice instruction onto this user computation. - StatusOr AddSliceInstruction( - const SliceRequest& slice_request); - - // Enqueues a dynamic slice instruction onto this user computation. - StatusOr AddDynamicSliceInstruction( - const DynamicSliceRequest& dynamic_slice_request); - - // Enqueues a dynamic update slice instruction onto this user computation. - StatusOr AddDynamicUpdateSliceInstruction( - const DynamicUpdateSliceRequest& dynamic_update_slice_request); - - // Enqueues a concatenate instruction onto this user computation. - StatusOr AddConcatenateInstruction( - const ConcatenateRequest& concatenate_request); - - // Enqueues a convert instruction onto this user computation. 
- StatusOr AddConvertInstruction( - const ConvertRequest& convert_request); - - // Enqueues a bitcast element instruction onto this user computation. - StatusOr AddBitcastConvertInstruction( - const ConvertRequest& convert_request); - - // Enqueues a reduce instruction onto this user computation. - StatusOr AddReduceInstruction( - const ReduceRequest& reduce_request, - const UserComputation& to_apply_computation); - - // Enqueues a windowed reduce instruction onto this user computation. - StatusOr AddReduceWindowInstruction( - const ReduceWindowRequest& reduce_window_request, - const UserComputation& to_apply_computation); - - // Enqueues a select-and-scatter instruction onto this user - // computation. - StatusOr AddSelectAndScatterInstruction( - const SelectAndScatterRequest& select_and_scatter_request, - const UserComputation& select_computation, - const UserComputation& scatter_computation); - - // Enqueues a reverse instruction onto this user computation. - StatusOr AddReverseInstruction( - const ReverseRequest& reverse_request); - - // Enqueues a while instruction onto this user computation. - StatusOr AddWhileInstruction( - const WhileRequest& while_request, - const UserComputation& condition_computation, - const UserComputation& body_computation); - - // Enqueues a conditional instruction on this user computation. - StatusOr AddConditionalInstruction( - const ConditionalRequest& conditional_request, - const UserComputation& true_computation, - const UserComputation& false_computation); - - // Enqueues a Send instruction onto this user computation. - StatusOr AddSendInstruction( - const SendRequest& send_request); - - // Enqueues a Recv instruction onto this user computation. - StatusOr AddRecvInstruction( - const RecvRequest& recv_request); - - // Enqueues a Gather instruction onto this user computation. 
- StatusOr AddGatherInstruction( - const GatherRequest& gather_request); - - // Returns the user-provided name of this user computation, which is provided - // via the XLA computation-building API. - const string& name() const { return name_; } - - // Subsequent executions of this computation will compute the value - // represented by handle, rather than the last expression enqueued - // on the computation. - Status SetReturnValue(const ComputationDataHandle& handle); - - // Return a versioned handle for this computation. - VersionedComputationHandle GetVersionedHandle() const; - - // Return a versioned handle for this computation with a version equal to the - // point at which given operation was added to the computation. - VersionedComputationHandle GetVersionedHandleAtOperation( - const ComputationDataHandle& operation) const; - - // Return a version value representing the current state of the - // computation. - VersionedComputationHandle::Version version() const; - - // Computes and returns the program shape for the user computation -- gathers - // parameters and result type into a single proto. A shared_ptr is used - // because the returned pointer refers to an internally cached value which may - // be discarded by the UserComputation object. This avoid unnecessary copies. - // - // If the parameter space is not dense (i.e. there are holes in the parameter - // numbers provided) then an error status is returned. - StatusOr> ComputeProgramShape( - VersionedComputationHandle::Version version) const; - - // Returns true if the given data handle does not depend on any parameter with - // index higher then num_parameters. That is, the value can be computed at - // compile time if we know the first num_parameters arguments. - StatusOr IsConstant(const ComputationDataHandle& handle, - int64 num_parameters); - - // Returns the output shape of the operation indicated by the given handle. 
- StatusOr GetShape(const ComputationDataHandle& handle); - - // Sets metadata on the Hlo instruction referenced by the given handle. - Status SetOpMetadata(const ComputationDataHandle& handle, - const OpMetadata& metadata); - - // Sets the device assignment on the Hlo instruction referenced by 'handle'. - Status SetOpSharding(const ComputationDataHandle& handle, - const OpSharding& sharding); - - // Builds a HLO computation from the UserComputation. The parameter "resolver" - // is a function which returns a pointer to the HloComputation corresponding - // to the given ComputationHandle at the given version. The resolver is used - // for operations, such as map, which call other computations and need a - // pointer to the called HloComputation to construct the respective HLO - // instructions. If include_unreachable_instructions is true, then - // instructions which are not reachable from the root are lowered into - // HloInstructions. - using HloComputationResolver = - std::function; - StatusOr> BuildHloComputation( - VersionedComputationHandle::Version version, - HloComputationResolver hlo_resolver, const DebugOptions& debug_options, - bool include_unreachable_instructions = true) const; - - // Return a vector containing the embedded computations used by this - // UserComputation. Only embedded computations which are called directly by - // this UserComputation are included. That is, the transitive closure of - // embedded computations is not included. - std::vector GetEmbeddedComputations( - VersionedComputationHandle::Version version) const; - - // Returns the number of OperationRequest objects in this UserComputation. - // The 'version' of a computation is identical to the number of - // OperationRequests in the UserComputation. 
- int64 request_count(VersionedComputationHandle::Version version) const { - return version; - } - - // Returns a copy of the internal session state for this computation -- this - // is useful for serializing the guts of a user computation, though references - // to other handles (e.g. referred-to computations) must be handled with care - // in the serialization / de-serialization process. - SessionComputation CloneSessionComputation( - VersionedComputationHandle::Version version) const; - - // Warning: typically we don't want to look up computation data handles until - // the computation is finished being built, for consistency purposes. We - // expose this routine for error reporting purposes so that we can provide - // more meaningful error messages from the XLA service layer. - // - // Returns the operation request that the handle comes from. - StatusOr LookUpRequestForErrorReporting( - const ComputationDataHandle& handle) const; - - // Retrieves the parameter metadata for the given parameter number. - // - // If the parameter number is invalid for this computation, nullopt is - // returned. When the return value has_value(), nullptr will never be - // the held value. - tensorflow::gtl::optional ParameterMetadata( - int parameter_number) const; - - private: - // Warning: dangerous mutating operation that doesn't respect versioning. - // This is only used at initialization time when constructing from a - // SessionComputation a la MakeWithRemapping. - // - // Remaps references to old computations (with handle values in the keys of - // old_to_new) to the computation handle given in the values. This is useful - // when loading computations from snapshots, to finish initialization, before - // the user computation is released into the wild. - Status RemapEmbeddedComputations( - const std::map& old_to_new) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Returns the OperationRequest corresponding to the given handle. 
- StatusOr LookUpRequest( - const ComputationDataHandle& handle) const - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Creates a new ComputationDataHandle with the next available handle value. - ComputationDataHandle CreateComputationDataHandle() - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Checks whether the parameter numbers of the parameter operations are - // contiguous starting from zero. Returns appropriate error status if not. - Status CheckParametersAreContiguous( - VersionedComputationHandle::Version version) const - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - VersionedComputationHandle GetVersionedHandleInternal() const - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Name of the computation. - string name_; - - mutable tensorflow::mutex mutex_; - - // State of the computation as a record of all operation-building requests. - SessionComputation session_computation_ GUARDED_BY(mutex_); - - // Mapping from parameter number to operation request containing the - // respective ParameterRequest. - std::map parameters_ GUARDED_BY(mutex_); - - // The next ComputationDataHandle value to assign. Handle values are assigned - // sequentially. - int64 next_handle_value_ GUARDED_BY(mutex_); - - // If handle_to_return_.has_handle() then an Execution of this Computation - // will compute the value represented by handle_to_return_, otherwise it will - // compute the value of (next_handle_value_ - 1). - ComputationDataHandle handle_to_return_ GUARDED_BY(mutex_); - - // Memoized ProgramShape and its version. A shared_ptr is used because - // references to this object are returned by ComputeProgramShape. 
- mutable int64 program_shape_version_ GUARDED_BY(mutex_) = 0; - mutable std::shared_ptr program_shape_ GUARDED_BY(mutex_); - - TF_DISALLOW_COPY_AND_ASSIGN(UserComputation); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc deleted file mode 100644 index 2fa163953f638c0038e9f6bb11ce2a3742e0558c..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ /dev/null @@ -1,340 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/user_computation.h" - -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" - -namespace op = xla::testing::opcode_matchers; - -namespace xla { -namespace { - -using UserComputationTest = ::testing::Test; - -TEST_F(UserComputationTest, SimpleComputation) { - const Shape kScalarShape = ShapeUtil::MakeShape(F32, {}); - const Shape kVectorShape = ShapeUtil::MakeShape(F32, {2}); - - // Build a simple three operation computatation: - // - // %constant = Constant({123, 42}) - // %param = Param(0) - // %outfeed = Outfeed(%constant) - // - // Build the computation at two different versions and check invariants. 
- ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ConstantRequest constant_request; - *constant_request.mutable_literal() = - Literal::CreateR1({123.0f, 42.0f})->ToProto(); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle constant_handle, - computation.AddConstantInstruction(constant_request)); - - ParameterRequest param_request; - *param_request.mutable_shape() = kScalarShape; - param_request.set_parameter(0); - param_request.set_name("param0"); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle param_handle, - computation.AddParameterInstruction(param_request)); - OpMetadata metadata; - metadata.set_op_name("meta"); - TF_ASSERT_OK(computation.SetOpMetadata(param_handle, metadata)); - - OutfeedRequest outfeed_request; - *outfeed_request.mutable_operand() = constant_handle; - *outfeed_request.mutable_shape() = kVectorShape; - outfeed_request.set_outfeed_config("abc"); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle outfeed_handle, - computation.AddOutfeedInstruction(outfeed_request)); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - { - // Test the computation at the latest version. In this case, the most - // recently added operation is an outfeed. However, the outfeed is not the - // root because outfeeds cannot be the root of a computation. - VersionedComputationHandle latest_version = - computation.GetVersionedHandle(); - - // Program shape should have a single scalar parameter and scalar - // result. The outfeed instruction should not affect the program shape. - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr program_shape, - computation.ComputeProgramShape(latest_version.version)); - ASSERT_EQ(1, program_shape->parameters_size()); - EXPECT_TRUE( - ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0))); - EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result())); - - // Build the HLO computation. 
- TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - DebugOptions())); - // There should be one HloInstruction per UserComputation operation. - EXPECT_EQ(3, hlo_computation->instruction_count()); - // The root of the instruction should be the parameter instruction (not the - // outfeed). - EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); - } - - { - // Test the computation at the version right after the parameter instruction - // is added. - VersionedComputationHandle version_at_param = - computation.GetVersionedHandleAtOperation(param_handle); - - // Program shape should have a single scalar parameter, and scalar result. - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr program_shape, - computation.ComputeProgramShape(version_at_param.version)); - ASSERT_EQ(1, program_shape->parameters_size()); - EXPECT_TRUE( - ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0))); - EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result())); - - // There should be two instructions, one for the constant and one for the - // parameter. The outfeed instruction should not be included. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(version_at_param.version, hlo_resolver, - DebugOptions())); - EXPECT_EQ(2, hlo_computation->instruction_count()); - EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); - } - { - // Test the computation at the latest version, but lowered with - // include_unreachable_instructions set to false. - VersionedComputationHandle latest_version = - computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation( - latest_version.version, hlo_resolver, DebugOptions(), - /*include_unreachable_instructions=*/false)); - // There is only one reachable instruction, the parameter. 
- EXPECT_EQ(1, hlo_computation->instruction_count()); - // The root of the instruction should be the parameter instruction (not the - // outfeed). - EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); - EXPECT_EQ(hlo_computation->root_instruction()->metadata().op_name(), - "meta"); - } -} - -TEST_F(UserComputationTest, EliminateScalarBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // Build a binary computation with scalar broadcast. - // - // %a = Constant({123, 42}) - // %b = Constant(1) - // %add = Add(%a, %b) - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ConstantRequest a_request; - *a_request.mutable_literal() = - Literal::CreateR1({123.0f, 42.0f})->ToProto(); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddConstantInstruction(a_request)); - - ConstantRequest b_request; - *b_request.mutable_literal() = Literal::CreateR0(1.0f)->ToProto(); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddConstantInstruction(b_request)); - - BinaryOpRequest add; - add.set_binop(BINOP_ADD); - *add.mutable_lhs() = a_handle; - *add.mutable_rhs() = b_handle; - TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - // The binary operation has implicit scalar broadcast, should be converted - // to an explicit broadcast intruction and a binary instruction. 
- EXPECT_EQ(4, hlo_computation->instruction_count()); - EXPECT_THAT(hlo_computation->root_instruction(), op::Add()); - LOG(INFO) << hlo_computation->root_instruction()->ToString(); - const auto& operands = hlo_computation->root_instruction()->operands(); - ASSERT_EQ(2, operands.size()); - EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast || - operands[1]->opcode() == HloOpcode::kBroadcast); -} - -TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // Build a binary computation with degenerate broadcast. - // - // %a = Param({1, 2, 3}); - // %b = Param({1, 2, 1}); - // %add = Add(%a, %b, {}); - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ParameterRequest a_request; - *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 3}); - a_request.set_name("a"); - a_request.set_parameter(0); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddParameterInstruction(a_request)); - - ParameterRequest b_request; - *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 1}); - b_request.set_name("b"); - b_request.set_parameter(1); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddParameterInstruction(b_request)); - - const int64 kDevice = 7; - OpSharding sharding; - sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); - sharding.add_tile_assignment_dimensions(1); - sharding.add_tile_assignment_devices(kDevice); - - TF_EXPECT_OK(computation.SetOpSharding(b_handle, sharding)); - - BinaryOpRequest add; - add.set_binop(BINOP_ADD); - *add.mutable_lhs() = a_handle; - *add.mutable_rhs() = b_handle; - TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = 
computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - - // b a - // | | - // reshape | - // | | - // broadcast | - // \ / - // add - EXPECT_EQ(5, hlo_computation->instruction_count()); - ASSERT_THAT( - hlo_computation->root_instruction(), - op::Add(op::Parameter(), op::Broadcast(op::Reshape(op::Parameter())))); - - const HloInstruction* broadcast = - hlo_computation->root_instruction()->operand(1); - EXPECT_TRUE(broadcast->has_sharding()); - - const HloInstruction* reshape = broadcast->operand(0); - EXPECT_TRUE(reshape->has_sharding()); -} - -TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // Build a binary computation with in-dim broadcast and degenerate broadcast. - // - // %a = Param({2, 3}); - // %b = Param({2, 1, 4}); - // %add = Add(%a, %b, {0, 1}); - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ParameterRequest a_request; - *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 3}); - a_request.set_name("a"); - a_request.set_parameter(0); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddParameterInstruction(a_request)); - - ParameterRequest b_request; - *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 1, 4}); - b_request.set_name("b"); - b_request.set_parameter(1); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddParameterInstruction(b_request)); - - BinaryOpRequest add; - add.set_binop(BINOP_ADD); - *add.mutable_lhs() = a_handle; - *add.mutable_rhs() = b_handle; - add.add_broadcast_dimensions(0); - add.add_broadcast_dimensions(1); - TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); - - auto hlo_resolver = 
[](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - - // The binary operation has in-dim broadcast and degenerate broadcast, should - // first do the in-dim broadcast then convert the degnerate broadcast into a - // reshape and a broadcast. - // - // b a - // | | - // broadcast reshape - // | | - // | broadcast - // \ / - // add - EXPECT_EQ(6, hlo_computation->instruction_count()); - EXPECT_THAT(hlo_computation->root_instruction(), op::Add()); - const auto& operands = hlo_computation->root_instruction()->operands(); - ASSERT_EQ(2, operands.size()); - EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast && - operands[1]->opcode() == HloOpcode::kBroadcast); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/versioned_computation_handle.h b/tensorflow/compiler/xla/service/versioned_computation_handle.h deleted file mode 100644 index 5732a56caffa31dde52dff5c2775f9fde0cacfbd..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/versioned_computation_handle.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { - -// A data structure encapsulating a ComputationHandle and version value of that -// computation. This object is used to unambiguously refer to a particular -// computation in the service. -struct VersionedComputationHandle { - // A version value unambiguously specifying the state of the computation at a - // particular point in time as it is being built. This value is the - // ComputationDataHandle of the current root instruction. - using Version = int64; - - ComputationHandle handle; - Version version; - - string ToString() const; - bool operator==(const VersionedComputationHandle& other) const { - return (handle.handle() == other.handle.handle()) && - (version == other.version); - } - bool operator<(const VersionedComputationHandle& other) const { - return ((handle.handle() < other.handle.handle()) || - ((handle.handle() == other.handle.handle()) && - (version < other.version))); - } -}; - -std::ostream& operator<<(std::ostream& out, - const VersionedComputationHandle& versioned_handle); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index 0d2288d8ea6ebb0ac4ac9468a211b161438fc5f1..393e75803888d8a642881c4d525b170d1e1180ba 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -16,8 +16,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -55,7 +55,7 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); @@ -95,7 +95,7 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); @@ -136,7 +136,7 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); @@ -184,7 +184,7 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 321fdeb1ea313d2bc00b0210b422f36915f41453..09ddcffb22c2184262adf87d570870ec000c0e6f 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -98,14 +98,17 @@ static void CreateLoopInvariantCopy( // Returns true if `instruction` is worth hoisting only if it lets us hoist some // instruction using it. The rationale is that hoisting these instructions will // prevent simplification and fusion in the while body. 
-static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { +bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually( + const HloInstruction& instruction) { switch (instruction.opcode()) { default: return false; + case HloOpcode::kConstant: + return !hoist_constants_; + case HloOpcode::kBitcast: case HloOpcode::kBroadcast: - case HloOpcode::kConstant: case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kSlice: @@ -115,7 +118,8 @@ static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { } } -static StatusOr TryHoistingInvariantInstructionsFromWhileBody( +StatusOr +WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( HloInstruction* while_instr) { auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false); @@ -161,12 +165,16 @@ static StatusOr TryHoistingInvariantInstructionsFromWhileBody( } } - if (unhoisted_invariant_instructions.empty()) { + if (unhoisted_invariant_instructions.empty() && !hoist_constants_) { // There are no obviously loop invariant elements in the state being // threaded through the while loop so give up. In theory this precondition // is too strong -- we could have code that e.g. permutes the elements in // the while state but uses a select to pick the same value on every // iteration. + // + // If we were asked to hoist constants, we need to scan the while body for + // constants even if we didn't find any loop invariant values in the while + // state tuple. 
 return false; } @@ -243,6 +251,9 @@ static StatusOr TryHoistingInvariantInstructionsFromWhileBody( } StatusOr WhileLoopInvariantCodeMotion::Run(HloModule* module) { + VLOG(2) << "HLO module before WhileLoopInvariantCodeMotion:"; + XLA_VLOG_LINES(2, module->ToString()); + bool changed = false; std::vector while_instrs; for (auto* comp : module->computations()) { @@ -270,6 +281,14 @@ StatusOr WhileLoopInvariantCodeMotion::Run(HloModule* module) { TryHoistingInvariantInstructionsFromWhileBody(while_instr)); changed |= result; } + + if (changed) { + VLOG(2) << "HLO module after WhileLoopInvariantCodeMotion:"; + XLA_VLOG_LINES(2, module->ToString()); + } else { + VLOG(2) << "HLO module unchanged after WhileLoopInvariantCodeMotion"; + } + return changed; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h index 8c4b765b0003c48cfacb9d28e7c8259ac0927d66..8e6cc8787576e4f041229da5cf8dd2b09194eb2a 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -27,12 +27,28 @@ namespace xla { class WhileLoopInvariantCodeMotion : public HloPassInterface { public: + // If `hoist_constants` is true then constants are always hoisted out of while + // loop bodies. Otherwise they are only hoisted out if they enable other + // non-trivial computations to be hoisted out. + // + // Setting `hoist_constants` to false can help if LICM is run in the mid + // level HLO pipeline because hoisting constants out of while loop bodies can + // break optimizations like constant folding.
+ explicit WhileLoopInvariantCodeMotion(bool hoist_constants = false) + : hoist_constants_(hoist_constants) {} ~WhileLoopInvariantCodeMotion() override = default; tensorflow::StringPiece name() const override { return "while-loop-invariant-code-motion"; } StatusOr Run(HloModule* module) override; + + private: + bool NotWorthHoistingIndividually(const HloInstruction& instruction); + StatusOr TryHoistingInvariantInstructionsFromWhileBody( + HloInstruction* while_instr); + + bool hoist_constants_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 799340fda905fb7d40b19b4cb79bb0fcb5629fd3..8831c513eee66e36163135b732f833d46cb7eb03 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -438,5 +439,77 @@ TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { EXPECT_FALSE(simplified_loop); } +const char* const kConstantHoistingTestCase = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[2]{0}) parameter(0) + p_body.1 = f32[2]{0} get-tuple-element(p_body), index=0 + const = f32[2]{0} constant({3, 4}) + add.0 = f32[2]{0} add(p_body.1, const) + ROOT root = (f32[2]{0}) tuple(add.0) +} + +condition { + p_cond = (f32[2]{0}) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2]{0} constant({1, 2}) + while_init = (f32[2]{0}) tuple(const_0) + ROOT while = (f32[2]{0}) while(while_init), 
condition=condition, body=body +} +)"; + +TEST_F(WhileLoopInvariantCodeMotionTest, HoistsConstantWhenAsked) { + ParseAndVerifyModule(kConstantHoistingTestCase); + + TF_ASSERT_OK_AND_ASSIGN( + bool simplified_loop, + WhileLoopInvariantCodeMotion{/*hoist_constants=*/true}.Run(&module())); + EXPECT_TRUE(simplified_loop); + + HloComputation* while_body = module().GetComputationWithName("wide.body"); + ASSERT_NE(while_body, nullptr); + + // We expect the while body to be the equivalent of: + // + // wide.body { + // wide_param.1 = (f32[2]{0}, f32[2]{0}) parameter(0) + // get-tuple-element.1 = f32[2]{0} get-tuple-element(wide_param.1), index=0 + // tuple.1 = (f32[2]{0}) tuple(get-tuple-element.1) + // get-tuple-element.4 = f32[2]{0} get-tuple-element(tuple.1), index=0 + // get-tuple-element.7 = f32[2]{0} get-tuple-element(wide_param.1), index=1 + // add.1 = f32[2]{0} add(get-tuple-element.4, get-tuple-element.7) + // tuple.3 = (f32[2]{0}) tuple(add.1) + // get-tuple-element.8 = f32[2]{0} get-tuple-element(tuple.3), index=0 + // get-tuple-element.9 = f32[2]{0} get-tuple-element(wide_param.1), index=1 + // ROOT tuple.4 = (f32[2]{0}, f32[2]{0}) tuple(get-tuple-element.8, + // get-tuple-element.9) + // } + + auto wide_param_1 = op::Parameter(0); + auto get_tuple_element_1 = op::GetTupleElement(wide_param_1, 0); + auto tuple_1 = op::Tuple(get_tuple_element_1); + auto get_tuple_element_4 = op::GetTupleElement(tuple_1, 0); + auto get_tuple_element_7 = op::GetTupleElement(wide_param_1, 1); + auto add_1 = op::Add(get_tuple_element_4, get_tuple_element_7); + auto tuple_3 = op::Tuple(add_1); + auto get_tuple_element_8 = op::GetTupleElement(tuple_3, 0); + auto get_tuple_element_9 = op::GetTupleElement(wide_param_1, 1); + auto tuple_4 = op::Tuple(get_tuple_element_8, get_tuple_element_9); + + EXPECT_THAT(while_body->root_instruction(), tuple_4); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistConstantByDefault) { + ParseAndVerifyModule(kConstantHoistingTestCase); + + 
TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index ed20b36292a7f24385603627d74fc72ba6b3b724..473eab2ea84eb8faf745cbe299bc80bcc1b62a35 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -117,9 +117,13 @@ WhileUtil::MakeInstructionsLiveIn( HloInstruction* new_while = containing_computation->AddInstruction( HloInstruction::CreateWhile(new_while_shape, new_while_condition, new_while_body, new_while_init)); - TF_RETURN_IF_ERROR(containing_computation->ReplaceInstruction( - while_instr, TupleUtil::ExtractPrefix( - new_while, while_instr->shape().tuple_shapes_size()))); + + // We want to get rid of the old while instruction even if it has side + // effecting operations so we do a manual HloComputation::RemoveInstruction + // instead of relying on HloComputation::ReplaceInstruction. 
+ TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(TupleUtil::ExtractPrefix( + new_while, while_instr->shape().tuple_shapes_size()))); + TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr)); HloInstruction* while_body_param = new_while_body->parameter_instruction(0); std::vector live_in_instructions; diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index 322d27b88cae60cb051f5fafdde70e2aafedbc1e..e67636d80f4b682fe1335eae535fb86105ac082b 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ b/tensorflow/compiler/xla/service/while_util.h @@ -38,17 +38,21 @@ class WhileUtil { }; // Replaces `while_instr` with a new while instruction that is equivalent to - // `while_instr`, except that it has all of the HLO instructions in + // `while_instr` except that it has all of the HLO instructions in // `instructions` as live-in, loop invariant values. These new live in values // are represented as new elements appended to the parameter of the while // loop, which must be of tuple shape. GetTupleElement instructions computing // each new live in value is returned in the `while_body_live_in_values` // vector. // - // Precondition: `while_instr` must have a tuple shaped state. + // Deletes `while_instr` after replacing it. // - // Every instruction in `instructions` must be contained in the computation - // that contains `while_instr`. + // Preconditions: + // + // `while_instr` must have a tuple shaped state. + // + // Every instruction in `instructions` must be contained in the computation + // that contains `while_instr`. 
static StatusOr MakeInstructionsLiveIn( HloInstruction* while_instr, tensorflow::gtl::ArraySlice instructions); diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index 974bc542a34d0af6d41ed29f36df87f4c164a360..d79d3297213e832306ea4726483b0f215df0f5d3 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -16,8 +16,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { namespace { @@ -49,7 +50,7 @@ ENTRY entry { )"; TF_ASSIGN_OR_RETURN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); *entry_computation = module->entry_computation(); *param0 = (*entry_computation)->parameter_instruction(0); @@ -150,7 +151,7 @@ ENTRY main { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* while_body = module->GetComputationWithName("body"); @@ -163,5 +164,47 @@ ENTRY main { ASSERT_EQ(gte_list.size(), 1); EXPECT_EQ((*gte_list.begin())->name(), "gte.0"); } + +TEST(WhileUtilTest, AlwaysRemovePreviousWhileBody) { + const char* const hlo_string = R"( +HloModule WhileWithSideEffects + +body { + param.b = (s32[], s32[]) parameter(0) + gte.0 = s32[] get-tuple-element(param.b), index=0 + gte.1 = s32[] get-tuple-element(param.b), index=1 + add = s32[] add(gte.0, gte.1) + ROOT tuple = (s32[], s32[]) tuple(gte.0, add) +} + +cond { + param.c = (s32[], s32[]) parameter(0) + ROOT condition = pred[] infeed() +} + +ENTRY main { + init = (s32[], s32[]) parameter(0) + to_make_live_in = f32[100] parameter(1) + ROOT while = (s32[], s32[]) while(init), 
condition=cond, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + HloComputation* main = module->GetComputationWithName("main"); + HloInstruction* while_instr = main->root_instruction(); + HloInstruction* to_make_live_in = main->parameter_instruction(1); + + TF_ASSERT_OK_AND_ASSIGN( + WhileUtil::MakeInstructionsLiveInResult make_live_in_result, + WhileUtil::MakeInstructionsLiveIn(while_instr, + /*instructions=*/{to_make_live_in})); + + auto is_while = [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }; + EXPECT_EQ(c_count_if(main->instructions(), is_while), 1); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index 4f64fe8f835017c3c7093988ae947fe21c377406..14c35e7b84f07bebac33a9753ac26a8ee1418f1e 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_ +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status.h" namespace xla { @@ -32,99 +32,52 @@ class ServiceInterface { virtual ~ServiceInterface() = default; // TODO(b/31824348): Convert to use StatusOr. 
- virtual tensorflow::Status TransferToClient( - const TransferToClientRequest* arg, TransferToClientResponse* result) = 0; + virtual Status TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) = 0; - virtual tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, TransferToServerResponse* result) = 0; + virtual Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) = 0; - virtual tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, TransferToInfeedResponse* result) = 0; + virtual Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) = 0; - virtual tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) = 0; + virtual Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) = 0; - virtual tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) = 0; + virtual Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) = 0; - virtual tensorflow::Status LoadComputationSnapshot( - const LoadComputationSnapshotRequest* request, - LoadComputationSnapshotResponse* result) = 0; + virtual Status ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) = 0; - virtual tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) = 0; + virtual Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) = 0; - virtual tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) = 0; + virtual Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) = 0; - virtual tensorflow::Status ExecuteParallel( - const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) = 0; + virtual Status 
DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) = 0; - virtual tensorflow::Status ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, - ExecuteParallelResponse* result) = 0; - - virtual tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) = 0; - - virtual tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) = 0; - - virtual tensorflow::Status DeconstructTuple( - const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) = 0; - - virtual tensorflow::Status GetComputationStats( - const ComputationStatsRequest* arg, ComputationStatsResponse* result) = 0; - - virtual tensorflow::Status GetComputationGraphStats( + virtual Status GetComputationGraphStats( const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) = 0; - virtual tensorflow::Status GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) = 0; - - virtual tensorflow::Status GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) = 0; - - virtual tensorflow::Status CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) = 0; - - virtual tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) = 0; - - // Methods used by ComputationBuilder. 
- virtual tensorflow::Status Computation(const ComputationRequest* arg, - ComputationResponse* result) = 0; - - virtual tensorflow::Status Op(const OpRequest* arg, OpResponse* result) = 0; - - virtual tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) = 0; - - virtual tensorflow::Status SetReturnValue( - const SetReturnValueRequest* arg, SetReturnValueResponse* results) = 0; - - virtual tensorflow::Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) = 0; + virtual Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) = 0; - virtual tensorflow::Status ComputeConstant( - const ComputeConstantRequest* arg, ComputeConstantResponse* result) = 0; + virtual Status CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) = 0; - virtual tensorflow::Status ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, - ComputeConstantResponse* result) = 0; + virtual Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) = 0; - // Methods used by Computation. - virtual tensorflow::Status SnapshotComputation( - const SnapshotComputationRequest* ag, - SnapshotComputationResponse* result) = 0; + virtual Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) = 0; // Methods used by GlobalData. - virtual tensorflow::Status Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) = 0; + virtual Status Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) = 0; }; } // namespace xla diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index 789eba5780d37e1fd4d80ec881855951c8bba0eb..7ee366b27a82bdbcb7a63a57ea80194db8ca7df4 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -22,24 +22,24 @@ limitations under the License. 
namespace xla { -tensorflow::Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { +Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { if (!ShapeUtil::Compatible(other_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", ShapeUtil::HumanString(other_shape).c_str(), ShapeUtil::HumanString(shape()).c_str()); } shape_ = other_shape; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { +Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { if (!ShapeUtil::Compatible(*to_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", ShapeUtil::HumanString(*to_shape).c_str(), ShapeUtil::HumanString(shape()).c_str()); } *to_shape = shape_; - return tensorflow::Status::OK(); + return Status::OK(); } void ShapeLayout::SetToDefaultLayout() { diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h index a1dce758cd3ab3f204ce330fca2a7d2bdf57a2be..36806da599cc9b27286e67c128bb7f496f29c105 100644 --- a/tensorflow/compiler/xla/shape_layout.h +++ b/tensorflow/compiler/xla/shape_layout.h @@ -40,7 +40,7 @@ class ShapeLayout { // Assigns the layouts in this ShapeLayout to the Layout fields of the given // shape. 'to_shape' and the shape of the ShapeLayout object must be // compatible. - tensorflow::Status AssignLayoutToShape(Shape* to_shape) const; + Status AssignLayoutToShape(Shape* to_shape) const; // Returns true if the Layouts in this ShapeLayout match the layouts in the // given shape. Returns false otherwise. If the given shape is not compatible @@ -49,7 +49,7 @@ class ShapeLayout { // Copies the layout from the given shape into this ShapeLayout. 'other_shape' // must be compatible with the ShapeLayout's shape. 
- tensorflow::Status CopyLayoutFromShape(const Shape& other_shape); + Status CopyLayoutFromShape(const Shape& other_shape); // Clears (Layout::Clear) all the Layouts stored in this object. void Clear(); diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index ffaa40c2d673a2365342371ed8dab59565d1d08f..5b14953ebb243da7b9be6eafd46160db8bc62707 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -42,36 +42,20 @@ namespace internal { template struct ShapeTreeNode { // Data corresponding to this node. - T data; + std::pair data; - // Children of this node. - std::vector> children; + // Children of this node, as indices into the container's nodes_ array. + std::vector children; - ShapeTreeNode() = default; - explicit ShapeTreeNode(const T& data) : data(data) {} - - ShapeTreeNode(const ShapeTreeNode& other) - : data(other.data), children(other.children.size()) { - for (size_t i = 0; i < children.size(); ++i) { - children[i] = ::xla::MakeUnique(*other.children[i]); - } - } - - ShapeTreeNode& operator=(const ShapeTreeNode& other) { - if (this != &other) { - data = other.data; - children.resize(other.children.size()); - for (size_t i = 0; i < children.size(); ++i) { - children[i] = ::xla::MakeUnique(*other.children[i]); - } - } - return *this; - } + explicit ShapeTreeNode(ShapeIndex index) + : ShapeTreeNode(std::move(index), T()) {} + ShapeTreeNode(ShapeIndex index, T data) + : data(std::move(index), std::move(data)) {} }; } // namespace internal -template +template class ShapeTreeIterator; // A ShapeTree is a recursive data structure which mirrors the structure of a @@ -95,10 +79,9 @@ class ShapeTreeIterator; // before its ShapeTree goes away. template class ShapeTree { - friend class ShapeTreeIterator; - friend class ShapeTreeIterator; - public: + using Node = internal::ShapeTreeNode; + // Default constructor creates a tree with a nil shape (i.e. an empty tuple). 
ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {} @@ -110,30 +93,12 @@ class ShapeTree { // alive longer than this ShapeTree. explicit ShapeTree(Shape shape); explicit ShapeTree(const Shape* shape); + explicit ShapeTree(const std::shared_ptr& shape); // Create ShapeTree with the given shape, and init_value for all nodes. ShapeTree(Shape shape, const T& init_value); ShapeTree(const Shape* shape, const T& init_value); - - ShapeTree(const ShapeTree& other) { *this = other; } - ShapeTree(ShapeTree&&) = default; - - ShapeTree& operator=(const ShapeTree& other) { - root_ = other.root_; - - // Fix up internal pointer if necessary. - if (other.shape_storage_) { - CHECK_EQ(other.shape_, other.shape_storage_.get()); - shape_storage_.reset(new Shape(*other.shape_)); - shape_ = shape_storage_.get(); - } else { - shape_ = other.shape_; - } - - return *this; - } - - ShapeTree& operator=(ShapeTree&& other) = default; + ShapeTree(const std::shared_ptr& shape, const T& init_value); // Returns the data element associated with the array in the shape at the // given index (see ShapeUtil::GetSubshape for how indexes are defined). @@ -161,63 +126,70 @@ class ShapeTree { return Lookup(index)->children.empty(); } - // iterator implements a forward_iterator with value_type = - // std::pair - using iterator = ShapeTreeIterator; - using const_iterator = ShapeTreeIterator; + ShapeTree(const ShapeTree&) = default; + ShapeTree& operator=(const ShapeTree&) = default; + ShapeTree(ShapeTree&&) = default; + ShapeTree& operator=(ShapeTree&& other) = default; + + // iterator implements a bidirectional_iterator with + // value_type = std::pair. + // + // The iteration order is guaranteed to be a pre-order walk of the ShapeTree. 
+ using iterator = + ShapeTreeIterator, typename std::vector::iterator, + std::pair>; + using const_iterator = + ShapeTreeIterator, + typename std::vector::const_iterator, + const std::pair>; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; // begin/end for iterating over all nodes. iterator begin() { - return iterator(&root_, /*iterate_leaves_only=*/false, - /*reverse=*/false); + return iterator(&nodes_, nodes_.begin(), + /*iterate_leaves_only=*/false); } iterator end() { - return iterator(nullptr, /*iterate_leaves_only=*/false, - /*reverse=*/false); + return iterator(&nodes_, nodes_.end(), + /*iterate_leaves_only=*/false); } const_iterator begin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/false, - /*reverse=*/false); + return const_iterator(&nodes_, nodes_.begin(), + /*iterate_leaves_only=*/false); } const_iterator end() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/false, - /*reverse=*/false); + return const_iterator(&nodes_, nodes_.end(), + /*iterate_leaves_only=*/false); } // rbegin/rend for iterating over all nodes in reverse. - iterator rbegin() { - return iterator(&root_, /*iterate_leaves_only=*/false, - /*reverse=*/true); - } - iterator rend() { - return iterator(nullptr, /*iterate_leaves_only=*/false, - /*reverse=*/true); + reverse_iterator rbegin() { return reverse_iterator(end()); } + reverse_iterator rend() { return reverse_iterator(begin()); } + const_reverse_iterator rbegin() const { + return const_reverse_iterator(end()); } - const_iterator rbegin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/false, - /*reverse=*/true); - } - const_iterator rend() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/false, - /*reverse=*/true); + const_reverse_iterator rend() const { + return const_reverse_iterator(begin()); } // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no // children). 
iterator leaf_begin() { - return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/false); + return iterator(&nodes_, nodes_.begin(), + /*iterate_leaves_only=*/true); } iterator leaf_end() { - return iterator(nullptr, /*iterate_leaves_only=*/true, - /*reverse=*/false); + return iterator(&nodes_, nodes_.end(), + /*iterate_leaves_only=*/true); } const_iterator leaf_begin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/true, - /*reverse=*/false); + return const_iterator(&nodes_, nodes_.begin(), + /*iterate_leaves_only=*/true); } const_iterator leaf_end() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/true, - /*reverse=*/false); + return const_iterator(&nodes_, nodes_.end(), + /*iterate_leaves_only=*/true); } // range-based iterator for leaf_begin()/leaf_end(). tensorflow::gtl::iterator_range leaves() { @@ -227,22 +199,32 @@ class ShapeTree { return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); } - iterator leaf_rbegin() { - return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/true); + reverse_iterator leaf_rbegin() { return reverse_iterator(leaf_end()); } + reverse_iterator leaf_rend() { return reverse_iterator(leaf_begin()); } + const_reverse_iterator leaf_rbegin() const { + return const_reverse_iterator(leaf_end()); } - iterator leaf_rend() { - return iterator(nullptr, /*iterate_leaves_only=*/true, - /*reverse=*/true); + const_reverse_iterator leaf_rend() const { + return const_reverse_iterator(leaf_begin()); } - const_iterator leaf_rbegin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/true, - /*reverse=*/true); + + // Returns an iterator pointing to the given ShapeIndex. + // REQUIRES: index must exist in the ShapeTree. 
+ iterator find(const ShapeIndex& index) { + Node* element = Lookup(index); + return iterator(&nodes_, typename std::vector::iterator(element), + /*iterate_leaves_only=*/false); } - const_iterator leaf_rend() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/true, - /*reverse=*/true); + const_iterator find(const ShapeIndex& index) const { + Node* element = Lookup(index); + return iterator(&nodes_, + typename std::vector::const_iterator(element), + /*iterate_leaves_only=*/false); } + // Returns the number of leaf nodes in the tree. + int64 leaf_count() const { return std::distance(leaf_begin(), leaf_end()); } + // Recursively traverses the shape and calls the given function at each // element. The function has the following arguments: // @@ -282,8 +264,6 @@ class ShapeTree { bool operator!=(const ShapeTree& other) const { return !(*this == other); } private: - using Node = internal::ShapeTreeNode; - // Initialize node->children based on 'shape'. All children are assigned the // the given 'init_value'. void InitChildren(const Shape& shape, const T& init_value, Node* node); @@ -292,136 +272,57 @@ class ShapeTree { // default-constructed data values. void InitChildren(const Shape& shape, Node* node); + // Returns the number of subshapes, including interior nodes, in shape. + int64 CountSubshapes(const Shape& shape); + // Helpers for traversing the shape via ForEachElement. The helpers // recursively traverse the subtree rooted at "index" (defined as in // ShapeUtil::GetSubshape). template - static Status ForEachHelper(const Fn& func, const Node& node, - ShapeIndex* index); + static Status ForEachHelper(const Fn& func, const std::vector& nodes); template - static Status ForEachMutableHelper(const Fn& func, Node* node, - ShapeIndex* index); + static Status ForEachMutableHelper(const Fn& func, std::vector* nodes); // Return the tree node at the given index. 
Node* Lookup(const ShapeIndex& index); const Node* Lookup(const ShapeIndex& index) const; - // The root node, which contains all other nodes. - Node root_; + // The nodes in this shape tree. + std::vector nodes_; // If we own our Shape, this field contains it, and shape_ is a pointer into // here. Otherwise if we don't own our shape, this is nullptr. - std::unique_ptr shape_storage_; + std::shared_ptr shape_storage_; // The XLA shape mirrored in this ShapeTree. This is either // shape_storage_.get() or the Shape pointer passed to our constructor. const Shape* shape_; }; -// Internal iterator that performs a pre-order walk. This is copyable, but -// contains a vector so isn't cheap to copy. This also means post-increment is -// expensive. The iterator value_type is equivalent to a std::pair, similar to std::map. The non-const iterator's T& type can be mutated -// in-place. -template -class ShapeTreeIterator : public std::iterator> { +// Internal iterator that performs a pre-order walk. This is cheap to copy. +// The iterator value_type is equivalent to a +// std::pair&, similar to std::map. +template +class ShapeTreeIterator + : public std::iterator { public: - using value_type = - typename std::conditional, - std::pair>::type; - using NodeType = - typename std::conditional::Node, - typename ShapeTree::Node>::type; - - // Construct an iterator pointing at node. Node must either be the tree root - // or nullptr (which is equivalent to end() and should not be dereferenced or - // incremented). If iterate_leaves_only is true, the iterator will not include - // interior tree nodes, only leaves. If reverse is true, the iterator will - // visit nodes in the reverse of pre-order traversal. 
- ShapeTreeIterator(NodeType* node, bool iterate_leaves_only, bool reverse) - : node_(node), - iterate_leaves_only_(iterate_leaves_only), - reverse_(reverse) { - if (node_) { - if (reverse_) { - while (!node_->children.empty()) { - const int child_index = node_->children.size() - 1; - stack_.push_back({node_, child_index}); - node_ = node_->children[child_index].get(); - } - } else { - if (!node_->children.empty() && iterate_leaves_only) { - ++*this; - } - } + ShapeTreeIterator(ContainerType* nodes, IteratorType node, + bool iterate_leaves_only) + : nodes_(nodes), + node_(std::move(node)), + iterate_leaves_only_(iterate_leaves_only) { + while (iterate_leaves_only && node_ != nodes_->end() && + !node_->children.empty()) { + ++node_; } } - ShapeTreeIterator(const ShapeTreeIterator& other) - : node_(other.node_), - stack_(other.stack_), - iterate_leaves_only_(other.iterate_leaves_only_), - reverse_(other.reverse_) {} ShapeTreeIterator& operator++() { - CHECK_NE(nullptr, node_) << "walking off the end() of an iterator!"; - if (reverse_) { - while (!stack_.empty()) { - node_ = stack_.back().first; - int64 next_child_index = stack_.back().second - 1; - stack_.pop_back(); - if (next_child_index < 0) { - if (!iterate_leaves_only_) { - // All children are visited, yield . - return *this; - } - } else { - stack_.push_back({node_, next_child_index}); - node_ = node_->children[next_child_index].get(); - while (!node_->children.empty()) { - const int child_index = node_->children.size() - 1; - stack_.push_back({node_, child_index}); - node_ = node_->children[child_index].get(); - } - return *this; - } - } - } else { - // We're doing a pre-order walk, so if our current node has children take - // the first child. - if (!node_->children.empty()) { - stack_.push_back({node_, /*child-index=*/0}); - node_ = node_->children[0].get(); - if (node_->children.empty() || !iterate_leaves_only_) { - return *this; - } else { - // This is a non-leaf; tail-recurse. 
- return ++(*this); - } - } - // Otherwise we are currently at a leaf. Walk back up until a node - // contains a child we haven't visited yet. - while (!stack_.empty()) { - node_ = stack_.back().first; - int64 next_child_index = stack_.back().second + 1; - stack_.pop_back(); - if (node_->children.size() > next_child_index) { - stack_.push_back({node_, next_child_index}); - node_ = node_->children[next_child_index].get(); - - if (node_->children.empty() || !iterate_leaves_only_) { - return *this; - } else { - // This is a non-leaf; tail-recurse. - return ++(*this); - } - } - } + ++node_; + while (iterate_leaves_only_ && node_ != nodes_->end() && + !node_->children.empty()) { + ++node_; } - // We've walked off the end of the tree. Set node_ to nullptr to signify - // end(). - node_ = nullptr; - current_.reset(); return *this; } ShapeTreeIterator operator++(int) { @@ -429,52 +330,62 @@ class ShapeTreeIterator : public std::iterator nodes_->begin() && + !node_->children.empty()) { + --node_; + } + return *this; + } + ShapeTreeIterator operator--(int) { + auto i = *this; + --(*this); + return i; + } + bool operator==(const ShapeTreeIterator& other) const { return node_ == other.node_; } bool operator!=(const ShapeTreeIterator& other) const { return node_ != other.node_; } - value_type& operator*() { return UpdateCurrent(); } - value_type* operator->() { return &UpdateCurrent(); } + ValueType& operator*() { return node_->data; } + ValueType* operator->() { return &node_->data; } private: - // Updates the current_ member to reflect the current state. - value_type& UpdateCurrent() { - ShapeIndex index; - for (auto& node_and_index : stack_) { - index.push_back(node_and_index.second); - } - current_ = ::xla::MakeUnique(index, node_->data); - return *current_; - } - - // The node to which this iterator is pointing. This is the source of truth in - // the iterator - the stack only exists to facilitate walking back from - // children to parents. 
- NodeType* node_; - // Stack of {node, child-index} pairs of the path taken from the root to get - // to node_. This allows us to backtrack and know where to go next. - std::vector> stack_; + ContainerType* nodes_; + IteratorType node_; // True if we should not include interior nodes in our walk. bool iterate_leaves_only_; - // True if we should yield the reverse of the pre-order traversal. - bool reverse_; - // Placeholder for the current value. Ideally this wouldn't exist and would - // just be an rvalue, but operator -> needs to return a pointer to something. - // We cannot just use a plain old value_type as it contains a reference so - // cannot be default-constructed. - std::unique_ptr current_; }; +template +int64 ShapeTree::CountSubshapes(const Shape& shape) { + int64 current_count = 1; + if (ShapeUtil::IsTuple(shape)) { + int64 count = ShapeUtil::TupleElementCount(shape); + for (int i = 0; i < count; ++i) { + current_count += CountSubshapes(shape.tuple_shapes(i)); + } + } + return current_count; +} + template void ShapeTree::InitChildren(const Shape& shape, const T& init_value, Node* node) { if (ShapeUtil::IsTuple(shape)) { - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - node->children.emplace_back(new Node(init_value)); - InitChildren(shape.tuple_shapes(i), init_value, - node->children.back().get()); + const int64 size = ShapeUtil::TupleElementCount(shape); + node->children.reserve(size); + ShapeIndex shape_index = node->data.first; + shape_index.push_back(0); + for (int i = 0; i < size; ++i) { + shape_index[shape_index.size() - 1] = i; + node->children.push_back(nodes_.size()); + nodes_.emplace_back(shape_index, init_value); + InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back()); } } } @@ -482,63 +393,92 @@ void ShapeTree::InitChildren(const Shape& shape, const T& init_value, template void ShapeTree::InitChildren(const Shape& shape, Node* node) { if (ShapeUtil::IsTuple(shape)) { - for (int i = 0; i < 
ShapeUtil::TupleElementCount(shape); ++i) { - node->children.emplace_back(new Node()); - InitChildren(shape.tuple_shapes(i), node->children.back().get()); + const int64 size = ShapeUtil::TupleElementCount(shape); + node->children.reserve(size); + ShapeIndex shape_index = node->data.first; + shape_index.push_back(0); + for (int i = 0; i < size; ++i) { + shape_index[shape_index.size() - 1] = i; + node->children.push_back(nodes_.size()); + nodes_.emplace_back(shape_index); + InitChildren(shape.tuple_shapes(i), &nodes_.back()); } } } template ShapeTree::ShapeTree(Shape shape) - : root_(), - shape_storage_(::xla::MakeUnique(std::move(shape))), + : shape_storage_(std::make_shared(std::move(shape))), shape_(shape_storage_.get()) { // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. LayoutUtil::ClearLayout(shape_storage_.get()); - InitChildren(*shape_, &root_); + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}); + InitChildren(*shape_, &nodes_[0]); +} + +template +ShapeTree::ShapeTree(const Shape* shape) : shape_(shape) { + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}); + InitChildren(*shape_, &nodes_[0]); } template -ShapeTree::ShapeTree(const Shape* shape) : root_(), shape_(shape) { - InitChildren(*shape_, &root_); +ShapeTree::ShapeTree(const std::shared_ptr& shape) + : shape_storage_(shape), shape_(shape_storage_.get()) { + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}); + InitChildren(*shape_, &nodes_[0]); } template ShapeTree::ShapeTree(Shape shape, const T& init_value) - : root_(init_value), - shape_storage_(::xla::MakeUnique(std::move(shape))), + : shape_storage_(std::make_shared(std::move(shape))), shape_(shape_storage_.get()) { // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. 
LayoutUtil::ClearLayout(shape_storage_.get()); - InitChildren(*shape_, init_value, &root_); + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}, init_value); + InitChildren(*shape_, init_value, &nodes_[0]); } template ShapeTree::ShapeTree(const Shape* shape, const T& init_value) - : root_(init_value), shape_(shape) { - InitChildren(*shape_, init_value, &root_); + : shape_(shape) { + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}, init_value); + InitChildren(*shape_, init_value, &nodes_[0]); +} + +template +ShapeTree::ShapeTree(const std::shared_ptr& shape, + const T& init_value) + : shape_storage_(shape), shape_(shape_storage_.get()) { + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}, init_value); + InitChildren(*shape_, init_value, &nodes_[0]); } template const T& ShapeTree::element(const ShapeIndex& index) const { - return Lookup(index)->data; + return Lookup(index)->data.second; } template T* ShapeTree::mutable_element(const ShapeIndex& index) { - return &Lookup(index)->data; + return &Lookup(index)->data.second; } template internal::ShapeTreeNode* ShapeTree::Lookup(const ShapeIndex& index) { - Node* node = &root_; + Node* node = &nodes_[0]; for (const int64 i : index) { CHECK_GE(i, 0); CHECK_LT(i, node->children.size()); - node = node->children[i].get(); + node = &nodes_[node->children[i]]; } return node; } @@ -552,13 +492,10 @@ const internal::ShapeTreeNode* ShapeTree::Lookup( /* static */ template template -Status ShapeTree::ForEachHelper(const Fn& func, const Node& node, - ShapeIndex* index) { - TF_RETURN_IF_ERROR(func(*index, node.data)); - for (int64 i = 0; i < node.children.size(); ++i) { - index->push_back(i); - TF_RETURN_IF_ERROR(ForEachHelper(func, *node.children[i], index)); - index->pop_back(); +Status ShapeTree::ForEachHelper(const Fn& func, + const std::vector& nodes) { + for (const auto& node : nodes) { + TF_RETURN_IF_ERROR(func(node.data.first, node.data.second)); 
} return Status::OK(); } @@ -566,14 +503,10 @@ Status ShapeTree::ForEachHelper(const Fn& func, const Node& node, /* static */ template template -Status ShapeTree::ForEachMutableHelper(const Fn& func, Node* node, - ShapeIndex* index) { - TF_RETURN_IF_ERROR(func(*index, &node->data)); - for (int64 i = 0; i < node->children.size(); ++i) { - index->push_back(i); - TF_RETURN_IF_ERROR( - ForEachMutableHelper(func, node->children[i].get(), index)); - index->pop_back(); +Status ShapeTree::ForEachMutableHelper(const Fn& func, + std::vector* nodes) { + for (auto& node : *nodes) { + TF_RETURN_IF_ERROR(func(node.data.first, &node.data.second)); } return Status::OK(); } @@ -581,40 +514,36 @@ Status ShapeTree::ForEachMutableHelper(const Fn& func, Node* node, template template Status ShapeTree::ForEachElementWithStatus(const Fn& func) const { - ShapeIndex index; - return ForEachHelper(func, root_, &index); + return ForEachHelper(func, nodes_); } template template Status ShapeTree::ForEachMutableElementWithStatus(const Fn& func) { - ShapeIndex index; - return ForEachMutableHelper(func, &root_, &index); + return ForEachMutableHelper(func, &nodes_); } template template void ShapeTree::ForEachElement(const Fn& func) const { - ShapeIndex index; return ForEachHelper( [&func](const ShapeIndex& index, const T& data) { func(index, data); return Status::OK(); }, - root_, &index) + nodes_) .IgnoreError(); } template template void ShapeTree::ForEachMutableElement(const Fn& func) { - ShapeIndex index; return ForEachMutableHelper( [&func](const ShapeIndex& index, T* data) { func(index, data); return Status::OK(); }, - &root_, &index) + &nodes_) .IgnoreError(); } diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index 4b6ab772811f4a6c6ffc1d10befc7122f883b8f9..dc5facf1581c07fbb74dfcee95025692938632bd 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -18,6 +18,7 @@ limitations under the 
License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test_benchmark.h" namespace xla { namespace { @@ -421,8 +422,8 @@ TEST_F(ShapeTreeTest, IterateAndMutate) { } ++i; } - t.begin()->second = 78; - EXPECT_EQ(78, t.begin()->second); + (*t.begin()).second = 78; + EXPECT_EQ(78, (*t.begin()).second); i = 0; for (auto& index_to_data : t) { if (i == 0) { @@ -434,14 +435,14 @@ TEST_F(ShapeTreeTest, IterateAndMutate) { } ++i; } - EXPECT_EQ(78, t.begin()->second); - EXPECT_EQ(98, std::next(t.begin())->second); + EXPECT_EQ(78, (*t.begin()).second); + EXPECT_EQ(98, (*std::next(t.begin())).second); } TEST_F(ShapeTreeTest, IterateOrder) { ShapeTree t(nested_tuple_shape_, 42); std::vector v; - for (auto& index_to_data : t) { + for (auto index_to_data : t) { v.push_back(index_to_data.first); } EXPECT_EQ(v, (std::vector{{}, @@ -479,7 +480,7 @@ TEST_F(ShapeTreeTest, ReverseIterateOrder) { TEST_F(ShapeTreeTest, IterateOrderLeaves) { ShapeTree t(nested_tuple_shape_, 42); std::vector v; - for (auto& index_to_data : t.leaves()) { + for (auto index_to_data : t.leaves()) { v.push_back(index_to_data.first); } EXPECT_EQ(v, (std::vector{ @@ -502,5 +503,106 @@ TEST_F(ShapeTreeTest, ReverseIterateOrderLeaves) { })); } +void BM_Construct(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + for (int i = 0; i < iters; ++i) { + ShapeTree shape_tree(shape); + } +} + +void BM_ConstructUnowned(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = 
ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + for (int i = 0; i < iters; ++i) { + ShapeTree shape_tree(&shape); + } +} + +void BM_Copy(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + ShapeTree shape_tree(shape); + for (int i = 0; i < iters; ++i) { + ShapeTree copy = shape_tree; + tensorflow::testing::DoNotOptimize(copy); + } +} + +void BM_Move(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + ShapeTree shape_tree(shape); + for (int i = 0; i < iters; ++i) { + ShapeTree copy = std::move(shape_tree); + shape_tree = std::move(copy); + } +} + +void BM_ForEach(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + ShapeTree shape_tree(shape); + for (int i = 0; i < iters; ++i) { + shape_tree.ForEachMutableElement([](const ShapeIndex& index, int* data) { + tensorflow::testing::DoNotOptimize(index); + }); + } +} + +void BM_Iterate(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + ShapeTree shape_tree(shape); + for (int i = 0; i < iters; ++i) { + for (auto& iter : 
shape_tree) { + tensorflow::testing::DoNotOptimize(iter.second); + } + } +} + +BENCHMARK(BM_Construct)->ArgPair(2, 8); +BENCHMARK(BM_ConstructUnowned)->ArgPair(2, 8); +BENCHMARK(BM_Copy)->ArgPair(2, 8); +BENCHMARK(BM_Move)->ArgPair(2, 8); +BENCHMARK(BM_ForEach)->ArgPair(2, 8); +BENCHMARK(BM_Iterate)->ArgPair(2, 8); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 7a897f6f8f99e65285e1be0757a55f703fc81c72..ce4d0079ee5eb28444509c712ec1a34037dc244a 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/iterator_range.h" @@ -42,17 +41,35 @@ limitations under the License. 
namespace xla { +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + string ShapeIndex::ToString() const { - return tensorflow::strings::StrCat( - "{", tensorflow::str_util::Join(indices_, ","), "}"); + return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}"); } string ShapeIndexView::ToString() const { - return tensorflow::strings::StrCat( - "{", - tensorflow::str_util::Join(tensorflow::gtl::make_range(begin_, end_), - ","), - "}"); + return StrCat("{", + tensorflow::str_util::Join( + tensorflow::gtl::make_range(begin_, end_), ","), + "}"); +} + +bool ShapeIndexView::operator==(const ShapeIndexView& other) const { + if (size() != other.size()) { + return false; + } + for (auto it = begin(), other_it = other.begin(); it != end(); + ++it, ++other_it) { + if (*it != *other_it) { + return false; + } + } + return true; +} + +bool ShapeIndexView::operator!=(const ShapeIndexView& other) const { + return !(*this == other); } std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) { @@ -67,18 +84,30 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) { namespace { +// Returns whether the given primitive type corresponds to an array shape. +bool IsArrayPrimitiveType(PrimitiveType primitive_type) { + return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && + primitive_type != OPAQUE && primitive_type != TOKEN; +} + // Recursive helper for comparing the equality of two shapes. Returns true if // the shapes are the same. If compare_layouts is true, then layouts must also // match. 
bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { - if (ShapeUtil::IsTuple(lhs) || ShapeUtil::IsTuple(rhs)) { - return ShapeUtil::IsTuple(lhs) && ShapeUtil::IsTuple(rhs) && - ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), + if (!ShapeUtil::SameElementType(lhs, rhs)) { + VLOG(3) << "CompareShapes: lhs element type != rhs element type"; + return false; + } + + if (ShapeUtil::IsTuple(lhs)) { + return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), [=](const Shape& l, const Shape& r) { return CompareShapes(l, r, compare_layouts); }); - } else if (ShapeUtil::IsOpaque(lhs) || ShapeUtil::IsOpaque(rhs)) { - return ShapeUtil::IsOpaque(lhs) && ShapeUtil::IsOpaque(rhs); + } else if (!ShapeUtil::IsArray(lhs)) { + // Non-tuple, non-array types such as opaque and token types are trivially + // the same. + return true; } if (compare_layouts) { @@ -108,10 +137,6 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; return false; } - if (!ShapeUtil::SameElementType(lhs, rhs)) { - VLOG(3) << "CompareShapes: lhs element type != rhs element type"; - return false; - } return true; } @@ -154,8 +179,8 @@ StatusOr MakeShapeWithLayoutInternal( } /* static */ int64 ShapeUtil::Rank(const Shape& shape) { - CHECK(!ShapeUtil::IsTuple(shape)) - << "Tuples do not have a rank, shape: " << shape; + CHECK(ShapeUtil::IsArray(shape)) + << "Non-arrays do not have a rank, shape: " << shape; return shape.dimensions_size(); } @@ -182,8 +207,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ Shape ShapeUtil::MakeShape( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions) { - DCHECK_NE(TUPLE, element_type); - DCHECK_NE(OPAQUE, element_type); + CHECK(IsArrayPrimitiveType(element_type)); Shape result; PopulateShape(element_type, dimensions, &result); return result; @@ -206,8 +230,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ Shape
ShapeUtil::MakeShapeWithSparseLayout( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, int64 max_sparse_elements) { - DCHECK_NE(TUPLE, element_type); - DCHECK_NE(OPAQUE, element_type); + CHECK(IsArrayPrimitiveType(element_type)); Shape shape = ShapeUtil::MakeShape(element_type, dimensions); *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); TF_DCHECK_OK(ShapeUtil::ValidateShape(shape)); @@ -254,6 +277,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return result; } +/* static */ Shape ShapeUtil::MakeTokenShape() { + Shape result; + result.set_element_type(TOKEN); + TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result)); + return result; +} + /* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape, Shape* tuple_shape) { TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape)); @@ -277,7 +307,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) { - if (shape.element_type() == TUPLE || shape.element_type() == OPAQUE) { + if (!IsArray(shape)) { return false; } return primitive_util::BitWidth(shape.element_type()) == bits; @@ -303,6 +333,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( case C64: case TUPLE: case OPAQUE: + case TOKEN: return false; default: @@ -318,6 +349,10 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return primitive_util::IsFloatingPointType(shape.element_type()); } +/* static */ bool ShapeUtil::IsArray(const Shape& shape) { + return IsArrayPrimitiveType(shape.element_type()); +} + /* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) { return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), IsTuple); @@ -371,7 +406,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { - CHECK(!IsTuple(shape)) << 
ShapeUtil::HumanString(shape); + CHECK(IsArray(shape)) << ShapeUtil::HumanString(shape); CHECK_EQ(shape.dimensions_size(), Rank(shape)); return std::accumulate( shape.dimensions().begin(), shape.dimensions().end(), 1LL, @@ -386,23 +421,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return shape.element_type() == F32 && Rank(shape) == 0; } -/* static */ string ShapeUtil::HumanString(const Shape& shape) { - if (IsTuple(shape)) { - string text = "("; - const char* prefix = ""; - for (const Shape& elem_shape : shape.tuple_shapes()) { - tensorflow::strings::StrAppend(&text, prefix, HumanString(elem_shape)); - prefix = ", "; - } - text += ")"; - return text; - } else { - return tensorflow::strings::StrCat( - tensorflow::str_util::Lowercase( - PrimitiveType_Name(shape.element_type())), - "[", tensorflow::str_util::Join(shape.dimensions(), ","), "]"); - } -} namespace { @@ -453,48 +471,56 @@ StatusOr StringToPrimitiveType(const string& name) { } // namespace -/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { +/* static */ string ShapeUtil::HumanString(const Shape& shape) { if (IsTuple(shape)) { string text = "("; const char* prefix = ""; for (const Shape& elem_shape : shape.tuple_shapes()) { - tensorflow::strings::StrAppend(&text, prefix, - HumanStringWithLayout(elem_shape)); + StrAppend(&text, prefix, HumanString(elem_shape)); prefix = ", "; } text += ")"; return text; - } else { - string result = tensorflow::strings::StrCat( - LowercasePrimitiveTypeName(shape.element_type()), "["); - for (int i = 0; i < shape.dimensions().size(); i++) { - tensorflow::strings::StrAppend(&result, (i > 0) ? 
"," : "", - shape.dimensions(i)); + } + return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[", + tensorflow::str_util::Join(shape.dimensions(), ","), "]"); +} + +/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { + if (IsTuple(shape)) { + string text = "("; + const char* prefix = ""; + for (const Shape& elem_shape : shape.tuple_shapes()) { + StrAppend(&text, prefix, HumanStringWithLayout(elem_shape)); + prefix = ", "; } - result += "]"; - if (!IsScalar(shape) && !IsOpaque(shape)) { - if (LayoutUtil::HasLayout(shape)) { - tensorflow::strings::StrAppend(&result, - LayoutUtil::HumanString(shape.layout())); - } + text += ")"; + return text; + } + string result = StrCat(LowercasePrimitiveTypeName(shape.element_type()), "["); + for (int i = 0; i < shape.dimensions().size(); i++) { + StrAppend(&result, (i > 0) ? "," : "", shape.dimensions(i)); + } + result += "]"; + if (!IsScalar(shape) && IsArray(shape)) { + if (LayoutUtil::HasLayout(shape)) { + StrAppend(&result, LayoutUtil::HumanString(shape.layout())); } - return result; } + return result; } /* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) { std::vector parameters; for (auto& shape : program_shape.parameters()) { const int i = parameters.size(); - parameters.push_back( - tensorflow::strings::StrCat(i < program_shape.parameter_names_size() - ? program_shape.parameter_names(i) - : "(unknown)", - ": ", HumanString(shape))); + parameters.push_back(StrCat(i < program_shape.parameter_names_size() + ? 
program_shape.parameter_names(i) + : "(unknown)", + ": ", HumanString(shape))); } - return tensorflow::strings::StrCat( - "(", tensorflow::str_util::Join(parameters, ", "), ") -> ", - HumanString(program_shape.result())); + return StrCat("(", tensorflow::str_util::Join(parameters, ", "), ") -> ", + HumanString(program_shape.result())); } namespace { @@ -564,14 +590,17 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { // Extract the primitive element type. TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type, StringToPrimitiveType(element_type_string)); - if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE || - primitive_type == OPAQUE) { + if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) { return InvalidArgument("Invalid element type string: \"%s\".", element_type_string.c_str()); } Shape result; - if (format_string.empty() && layout_string.empty()) { + if (primitive_type == OPAQUE) { + result = ShapeUtil::MakeOpaqueShape(); + } else if (primitive_type == TOKEN) { + result = ShapeUtil::MakeTokenShape(); + } else if (format_string.empty() && layout_string.empty()) { // Create a shape without a layout set. result = ShapeUtil::MakeShape(primitive_type, dimensions); } else if (format_string == "sparse") { @@ -616,43 +645,44 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { - if (lhs.element_type() == TUPLE) { + if (IsArray(lhs)) { + return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs); + } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible); + } else { + // Opaque, token, etc types are vacuously compatible. 
+ return true; } - if (lhs.element_type() == OPAQUE) { - return rhs.element_type() == OPAQUE; - } - return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs) { - if (lhs.element_type() == TUPLE) { + if (IsArray(lhs)) { + return IsArray(rhs) && SameDimensions(lhs, rhs); + } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), CompatibleIgnoringElementType); + } else { + // Opaque, token, etc types are vacuously compatible. + return true; } - if (lhs.element_type() == OPAQUE) { - return rhs.element_type() == OPAQUE; - } - return ShapeUtil::IsArray(rhs) && SameDimensions(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs) { - if (lhs.element_type() == TUPLE) { + if (IsArray(lhs)) { + return IsArray(rhs) && SameElementTypeIgnoringFpPrecision(lhs, rhs) && + CompatibleIgnoringElementType(lhs, rhs); + } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), CompatibleIgnoringFpPrecision); + } else { + // Opaque, token, etc types are vacuously compatible. 
+ return true; } - if (lhs.element_type() == OPAQUE) { - return rhs.element_type() == OPAQUE; - } - if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) { - return CompatibleIgnoringElementType(lhs, rhs); - } - return false; } /* static */ int64 ShapeUtil::GetDimension(const Shape& shape, @@ -674,10 +704,6 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { switch (primitive_type) { case PRED: return sizeof(int8); - case TUPLE: - LOG(FATAL) << "tuples have no definitive size"; - case OPAQUE: - LOG(FATAL) << "opaque have no definitive size"; case S8: return sizeof(int8); case S16: @@ -704,6 +730,13 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return sizeof(double); case C64: return sizeof(complex64); + case TOKEN: + // Tokens require no space. + return 0; + case TUPLE: + case OPAQUE: + LOG(FATAL) << PrimitiveType_Name(primitive_type) + << " primitive type has no definitive size"; default: LOG(FATAL) << "Unhandled primitive type " << primitive_type; } @@ -712,28 +745,32 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { /* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape, int64 pointer_size) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK_NE(OPAQUE, shape.element_type()); if (shape.element_type() == TUPLE) { return ByteSizeOfTupleIndexTable(shape, pointer_size); + } else if (IsArray(shape)) { + int64 byte_size = ByteSizeOfElements(shape); + if (LayoutUtil::IsSparseArray(shape)) { + byte_size += ByteSizeOfSparseIndices(shape); + } + return byte_size; + } else if (shape.element_type() == TOKEN) { + return 0; } - int64 byte_size = ByteSizeOfElements(shape); - if (LayoutUtil::IsSparseArray(shape)) { - byte_size += ByteSizeOfSparseIndices(shape); - } - return byte_size; + LOG(FATAL) << PrimitiveType_Name(shape.element_type()) + << " primitive type has no definitive size"; } /* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape, int64 pointer_size) { TF_DCHECK_OK(ValidateShape(shape)); - 
DCHECK_EQ(TUPLE, shape.element_type()); + CHECK_EQ(TUPLE, shape.element_type()); CHECK_GT(pointer_size, 0); return pointer_size * shape.tuple_shapes_size(); } /* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK(ShapeUtil::IsArray(shape)); + CHECK(ShapeUtil::IsArray(shape)); int64 allocated_element_count; if (LayoutUtil::IsSparseArray(shape)) { @@ -758,13 +795,17 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { /* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK(LayoutUtil::IsSparseArray(shape)); + CHECK(LayoutUtil::IsSparseArray(shape)); return LayoutUtil::MaxSparseElements(shape.layout()) * ShapeUtil::Rank(shape) * sizeof(int64); } /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( const Shape& shape) { + if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { + return InvalidArgument("shape has invalid element type: %s", + shape.ShortDebugString().c_str()); + } if (shape.element_type() == TUPLE) { if (shape.dimensions_size() != 0) { return InvalidArgument("tuples must not have dimensions specified"); @@ -780,10 +821,24 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { if (shape.tuple_shapes_size() > 0) { return InvalidArgument("non-tuple shape has tuple_shapes field"); } - if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { - return InvalidArgument("shape has invalid element type: %s", - shape.ShortDebugString().c_str()); + + // Tokens and opaques should not have layout or dimensions.
+ if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE) { + if (shape.dimensions_size() != 0) { + return InvalidArgument( + "shape has %s element type, but has dimensions field: %s", + LowercasePrimitiveTypeName(shape.element_type()).c_str(), + shape.ShortDebugString().c_str()); + } + if (shape.has_layout()) { + return InvalidArgument( + "shape has %s element type, but has layout field: %s", + LowercasePrimitiveTypeName(shape.element_type()).c_str(), + shape.ShortDebugString().c_str()); + } + return Status::OK(); } + if (Rank(shape) != shape.dimensions_size()) { return InvalidArgument( "shape's rank is mismatched with dimension count; rank=%lld " @@ -863,7 +918,30 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { return !IsTuple(GetSubshape(shape, index)); } +/* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) { + int64 count = 0; + ForEachSubshape(shape, [&](const Shape&, const ShapeIndex& index) { + if (IsLeafIndex(shape, index)) { + ++count; + } + }); + return count; +} + +/* static */ std::vector<ShapeUtil::IndexedShape> ShapeUtil::GetLeafShapes( + const Shape& shape) { + std::vector<IndexedShape> leaves; + ForEachSubshape(shape, [&](const Shape& sub_shape, const ShapeIndex& index) { + if (IsLeafIndex(shape, index)) { + leaves.emplace_back(index, sub_shape); + } + }); + return leaves; +} + /* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) { + CHECK(IsArray(shape)); + std::vector<int64> dimension_sizes; std::vector<int64> degenerate_dimensions; for (int64 i = 0; i < shape.dimensions_size(); ++i) { @@ -1028,6 +1106,9 @@ Status ForEachMutableSubshapeHelper( /* static */ std::tuple<bool, std::vector<int64>, std::vector<int64>> ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, const Shape& shape_post) { + CHECK(IsArray(shape_pre)); + CHECK(IsArray(shape_post)); + auto nil = std::make_tuple(false, std::vector<int64>(), std::vector<int64>()); std::vector<int64> deleted_indices; @@ -1085,6 +1166,9 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, /* 
static */ std::vector<std::pair<int64, int64>> ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, const Shape& output_shape) { + CHECK(IsArray(input_shape)); + CHECK(IsArray(output_shape)); + // Unmodified dimensions are merely common factors of rank 1. auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()), AsInt64Slice(output_shape.dimensions())); @@ -1138,8 +1222,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape) { - CHECK(LayoutUtil::HasLayout(input_shape) && - LayoutUtil::HasLayout(output_shape)); + CHECK(IsArray(input_shape)); + CHECK(IsArray(output_shape)); + CHECK(LayoutUtil::HasLayout(input_shape)); + CHECK(LayoutUtil::HasLayout(output_shape)); if (!SameElementType(input_shape, output_shape)) { return false; @@ -1301,6 +1387,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ tensorflow::gtl::optional<Shape> ShapeUtil::AlignLayouts( const Shape& input_shape, const Shape& output_shape) { + CHECK(IsArray(input_shape)); + CHECK(IsArray(output_shape)); + int64 input_rank = Rank(input_shape); int64 output_rank = Rank(output_shape); @@ -1435,6 +1524,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete, Shape shape) { + CHECK(IsArray(shape)); shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete); if (LayoutUtil::HasLayout(shape)) { Layout* layout = shape.mutable_layout(); @@ -1456,6 +1546,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ Shape ShapeUtil::FilterDimensions( const std::function<bool(int64)>& p, Shape shape) { + CHECK(IsArray(shape)); std::vector<int64> dims_to_delete; for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) { if (!p(i)) { diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 
82c75f85d838f94cb040e56d59d0e012af5e0db0..3853ada6ba65dbb1ac0754bcf753b4553ec260e7 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -132,6 +133,9 @@ class ShapeIndexView { return ShapeIndexView(new_begin, end_); } + bool operator==(const ShapeIndexView& other) const; + bool operator!=(const ShapeIndexView& other) const; + string ToString() const; private: @@ -150,12 +154,22 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index); // properties, which do invariant checks before / after the operation. class ShapeUtil { public: + // Data structure which describes the coordinates and the shape, of a tuple + // shaped sub-shape. + struct IndexedShape { + IndexedShape() = default; + IndexedShape(ShapeIndex index, Shape shape) + : index(std::move(index)), shape(std::move(shape)) {} + ShapeIndex index; + Shape shape; + }; + // Returns the number of elements are contained within the provided shape; // e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes // may not actually be able to store this number of elements. See // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of // elements that can be stored in a sparse shape. - // Precondition: !IsTuple(shape) + // Precondition: IsArray(shape) static int64 ElementsIn(const Shape& shape); // Returns true if 'shape' has zero elements. @@ -166,13 +180,11 @@ class ShapeUtil { // shapes. This includes only the size of the top-level buffer. For example, a // tuple is stored as an array of pointers to other buffers. 
In this case, // this method only returns the size of the pointer array. - // Precondition: (!ShapeUtil::IsTuple(shape) || pointer_size > 0) && - // !ShapeUtil::IsOpaque(shape) static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1); // Returns the number of bytes used to store the primitive_type. // - // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape) + // Precondition: ShapeUtil::IsArray(shape) static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type); // Returns the number of bytes required to store the tuple member pointers for @@ -231,7 +243,7 @@ class ShapeUtil { } // Returns the higher-precision element type if a and b are both floating - // point types; otherwise, checks that they have the same element type + // point types; otherwise, checks that that they have the same element type // and returns it. static PrimitiveType HigherPrecisionElementType(const Shape& a, const Shape& b) { @@ -279,10 +291,10 @@ class ShapeUtil { // Scalar-specific static bool IsScalar(const Shape& shape) { - return !IsTuple(shape) && !IsOpaque(shape) && Rank(shape) == 0; + return IsArray(shape) && Rank(shape) == 0; } static bool IsEffectiveScalar(const Shape& shape) { - return !IsTuple(shape) && !IsOpaque(shape) && TrueRank(shape) == 0; + return IsArray(shape) && TrueRank(shape) == 0; } static bool IsScalarF32(const Shape& shape); @@ -311,6 +323,10 @@ class ShapeUtil { // into a custom operation. static Shape MakeOpaqueShape(); + // Creates a token shape. Values of this shape are used for ordering + // side-effecting operations. + static Shape MakeTokenShape(); + // Appends a shape to the given tuple. static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape); @@ -410,11 +426,15 @@ class ShapeUtil { return shape.element_type() == OPAQUE; } + // Returns whether the shape is an token value used for ordering + // side-effecting operations. 
+ static bool IsToken(const Shape& shape) { + return shape.element_type() == TOKEN; + } + // Returns whether the shape is an array. Note that scalars are considered // arrays. - static bool IsArray(const Shape& shape) { - return !IsTuple(shape) && !IsOpaque(shape); - } + static bool IsArray(const Shape& shape); // Returns whether the shape is a tuple with at least one element which is // also a tuple. @@ -461,6 +481,13 @@ class ShapeUtil { // shape. static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index); + // Returns the number of leaves in the shape. + static int64 GetLeafCount(const Shape& shape); + + // Retrieves all the leaf shapes and their indexes, in the order walked by + // the ForEachSubshape() API. + static std::vector<IndexedShape> GetLeafShapes(const Shape& shape); + // Calls the given visitor function for each subshape of the given shape. // Subshapes are visited in DFS pre-order starting with the entire shape // (index {}). @@ -626,6 +653,28 @@ class ShapeUtil { .IgnoreError(); } + // These convenience wrappers don't take `base`, `count` and `incr` + // explicitly, but iterate over every element in `shape` instead. + + template <typename FnType> + static Status ForEachIndexWithStatus(const Shape& shape, + const FnType& visitor_function) { + std::vector<int64> base(shape.dimensions_size()); + std::vector<int64> incr(shape.dimensions_size(), 1); + return ForEachIndexWithStatus(shape, base, + /*count=*/AsInt64Slice(shape.dimensions()), + incr, visitor_function); + } + + template <typename FnType> + static void ForEachIndex(const Shape& shape, const FnType& visitor_function) { + ForEachIndexWithStatus(shape, + [&](tensorflow::gtl::ArraySlice<int64> indices) { + return StatusOr<bool>(visitor_function(indices)); + }) + .IgnoreError(); + } + // A parallel version of ForEachIndex(WithStatus). This can only be used if // the visitor_function is thread-safe and the order of iteration does not // matter. 
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index f7675e97da7b061bde063e5093256c2288f99c98..ecdb6532f1d743c7dacc266eeba615e19748ee27 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -93,12 +93,14 @@ TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) { } TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { - string shape_string = "(f32[1],(f32[2]), f32[3])"; + string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])"; TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString(shape_string)); Shape expected = ShapeUtil::MakeTupleShape({ ShapeUtil::MakeShape(F32, {1}), - ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}), + ShapeUtil::MakeOpaqueShape(), ShapeUtil::MakeShape(F32, {3}), }); ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) @@ -136,6 +138,23 @@ TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) { << "actual: " << ShapeUtil::HumanString(actual); } +TEST(ShapeUtilTest, ParseOpaqueType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString("opaque[]")); + Shape expected = ShapeUtil::MakeOpaqueShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST(ShapeUtilTest, ParseTokenType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString("token[]")); + Shape expected = ShapeUtil::MakeTokenShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + TEST(ShapeUtilTest, ParseInvalidShapeString) { string shape_strings[] = { "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", @@ -295,6 +314,9 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) { EXPECT_EQ(8, 
ShapeUtil::ByteSizeOfPrimitiveType(C64)); EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {}))); EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {10, 20}))); + + EXPECT_EQ(0, ShapeUtil::ByteSizeOfPrimitiveType(TOKEN)); + EXPECT_EQ(0, ShapeUtil::ByteSizeOf(ShapeUtil::MakeTokenShape())); } TEST(ShapeUtilTest, ByteSizeOfWithPadding) { @@ -449,19 +471,21 @@ TEST(ShapeUtilTest, IsLeafIndex) { TEST(ShapeUtilTest, HumanString) { Shape opaque = ShapeUtil::MakeOpaqueShape(); + Shape token = ShapeUtil::MakeTokenShape(); Shape scalar = ShapeUtil::MakeShape(F32, {}); Shape matrix = ShapeUtil::MakeShape(U32, {1, 2}); Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1}); Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2}); - Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix}); + Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token}); EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque)); + EXPECT_EQ("token[]", ShapeUtil::HumanString(token)); EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar)); EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix)); EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2)); EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])", ShapeUtil::HumanString(tuple)); - EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", ShapeUtil::HumanString(nested_tuple)); EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque)); @@ -470,8 +494,10 @@ TEST(ShapeUtilTest, HumanString) { EXPECT_EQ("s32[3,4]{0,1}", ShapeUtil::HumanStringWithLayout(matrix2)); EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})", ShapeUtil::HumanStringWithLayout(tuple)); - EXPECT_EQ("((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0})", - ShapeUtil::HumanStringWithLayout(nested_tuple)); + EXPECT_EQ( + "((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, " + "token[])", + 
ShapeUtil::HumanStringWithLayout(nested_tuple)); ProgramShape prog = ShapeUtil::MakeProgramShape( {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple); @@ -481,8 +507,9 @@ TEST(ShapeUtilTest, HumanString) { "(unknown): u32[1,2], " "(unknown): s32[3,4], " "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), " - "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> " - "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) " + "-> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", ShapeUtil::HumanString(prog)); prog.add_parameter_names("arg0"); @@ -497,8 +524,10 @@ TEST(ShapeUtilTest, HumanString) { "matrix: u32[1,2], " "matrix2: s32[3,4], " "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), " - "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> " - "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], " + "token[])) " + "-> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", ShapeUtil::HumanString(prog)); } diff --git a/tensorflow/compiler/xla/status.h b/tensorflow/compiler/xla/status.h index 4eb3bf3766412d5d9a8e78a4652807c5eaeef6ee..69abb51852ac09e8d357a9ba7924efc348ef2001 100644 --- a/tensorflow/compiler/xla/status.h +++ b/tensorflow/compiler/xla/status.h @@ -21,7 +21,7 @@ limitations under the License. 
namespace xla { -using tensorflow::Status; +using tensorflow::Status; // TENSORFLOW_STATUS_OK } // namespace xla diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc index 7d76370e85d57fd6e27ee2d1ca1df068ccb5405a..377a618ffbd99316d409130df8a39f352664dee0 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -413,7 +413,7 @@ TEST(StatusOr, TestPointerValueConst) { EXPECT_EQ(&kI, thing.ValueOrDie()); } -// NOTE(tucker): tensorflow::StatusOr does not support this kind +// NOTE(tucker): StatusOr does not support this kind // of resize op. // TEST(StatusOr, StatusOrVectorOfUniquePointerCanResize) { // using EvilType = std::vector<std::unique_ptr<int>>; diff --git a/tensorflow/compiler/xla/test_helpers.h b/tensorflow/compiler/xla/test_helpers.h index 17bae2e4f611268df824ce793c75ba1c95573455..8918350135fbb86973b228b35f5873fea8695b2f 100644 --- a/tensorflow/compiler/xla/test_helpers.h +++ b/tensorflow/compiler/xla/test_helpers.h @@ -40,13 +40,10 @@ class Literal; namespace testing { namespace internal_status { -inline const ::tensorflow::Status& GetStatus( - const ::tensorflow::Status& status) { - return status; -} +inline const Status& GetStatus(const Status& status) { return status; } template <typename T> -inline const ::tensorflow::Status& GetStatus(const StatusOr<T>& status) { +inline const Status& GetStatus(const StatusOr<T>& status) { return status.status(); } } // namespace internal_status @@ -57,21 +54,17 @@ inline const ::tensorflow::Status& GetStatus(const StatusOr<T>& status) { // The following macros are similar to macros in gmock, but deliberately named // differently in order to avoid conflicts in files which include both. -// Macros for testing the results of functions that return tensorflow::Status or +// Macros for testing the results of functions that return Status or // StatusOr<T> (for any type T). 
-#define EXPECT_IS_OK(expression) \ - EXPECT_EQ(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) -#define EXPECT_IS_NOT_OK(expression) \ - EXPECT_NE(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) +#define EXPECT_IS_OK(expression) \ + EXPECT_EQ(Status::OK(), xla::testing::internal_status::GetStatus(expression)) +#define EXPECT_IS_NOT_OK(expression) \ + EXPECT_NE(Status::OK(), xla::testing::internal_status::GetStatus(expression)) #undef ASSERT_IS_OK -#define ASSERT_IS_OK(expression) \ - ASSERT_EQ(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) +#define ASSERT_IS_OK(expression) \ + ASSERT_EQ(Status::OK(), xla::testing::internal_status::GetStatus(expression)) #undef ASSERT_IS_NOT_OK -#define ASSERT_IS_NOT_OK(expression) \ - ASSERT_NE(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) +#define ASSERT_IS_NOT_OK(expression) \ + ASSERT_NE(Status::OK(), xla::testing::internal_status::GetStatus(expression)) #endif // TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_ diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index b982cf0dbc4ed00b9c0b0d98c1ec4e5584860717..e7e0a19db0516e4210f6bb78d6b5e6968bf78b2a 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -87,12 +87,12 @@ cc_library( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:error_spec", + "//tensorflow/compiler/xla:literal_comparison", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -117,11 +117,11 @@ cc_library( 
"//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", @@ -138,8 +138,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_verifier", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -152,7 +152,6 @@ tf_cc_binary( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", @@ -188,8 +187,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -288,8 +285,6 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - 
"//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -313,7 +308,6 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -335,7 +329,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -378,7 +371,6 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -398,7 +390,6 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -422,8 +413,6 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", 
"//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", @@ -450,8 +439,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -472,7 +459,6 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -491,7 +477,6 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -528,7 +513,6 @@ xla_test( tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", @@ -552,7 +536,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -572,8 +555,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", - 
"//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -598,8 +579,6 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -626,7 +605,6 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -641,6 +619,7 @@ xla_test( xla_test( name = "exhaustive_f32_elementwise_op_test", + size = "enormous", srcs = ["exhaustive_f32_elementwise_op_test.cc"], backends = [ "cpu", @@ -648,7 +627,6 @@ xla_test( ], shard_count = 48, tags = [ - "enormous", "manual", "notap", ], @@ -697,7 +675,6 @@ xla_test( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -720,8 +697,8 @@ xla_test( "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - 
"//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -741,7 +718,6 @@ xla_test( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -766,7 +742,6 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -790,7 +765,6 @@ xla_test( "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -802,30 +776,42 @@ xla_test( ], ) +CONVOLUTION_TEST_DEPS = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + 
"//tensorflow/core:lib", + "//tensorflow/core:test", +] + xla_test( name = "convolution_test", timeout = "long", srcs = ["convolution_test.cc"], shard_count = 25, - deps = [ - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:reference_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], + deps = CONVOLUTION_TEST_DEPS, +) + +xla_test( + name = "convolution_test_gpu_alternative_layout", + timeout = "long", + srcs = ["convolution_test.cc"], + backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]}, + backends = ["gpu"], + shard_count = 25, + deps = CONVOLUTION_TEST_DEPS, ) xla_test( @@ -843,7 +829,6 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -868,7 +853,6 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", 
"//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -930,8 +914,6 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", @@ -960,8 +942,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -1002,7 +982,6 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1055,8 +1034,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1078,7 +1055,6 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", 
"//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -1108,8 +1084,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", @@ -1221,9 +1195,25 @@ xla_test( ], deps = [ ":client_library_test_base", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "token_hlo_test", + srcs = ["token_hlo_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], + deps = [ + ":client_library_test_base", + "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -1240,8 +1230,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -1281,7 +1269,6 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:reference_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", 
"//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1304,7 +1291,6 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1344,7 +1330,6 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1362,7 +1347,6 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1388,8 +1372,6 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1411,7 +1393,6 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1483,8 +1464,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", 
"//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", @@ -1532,7 +1511,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1545,6 +1523,30 @@ xla_test( ], ) +xla_test( + name = "cross_replica_sum_test", + srcs = ["cross_replica_sum_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test", + ], +) + xla_test( name = "bitcast_convert_test", srcs = ["bitcast_convert_test.cc"], @@ -1574,8 +1576,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", 
"//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -1596,7 +1596,6 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1620,8 +1619,6 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1642,7 +1639,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -1661,7 +1657,6 @@ xla_test( srcs = ["execution_profile_test.cc"], deps = [ ":client_library_test_base", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1676,7 +1671,6 @@ xla_test( args = ["--xla_hlo_profile"], deps = [ ":client_library_test_base", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1782,8 +1776,6 @@ xla_test( 
"//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1811,8 +1803,6 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1850,8 +1840,6 @@ xla_test( deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1880,8 +1868,6 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1949,8 +1935,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", 
"//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -2051,7 +2035,6 @@ xla_test( ":local_client_test_base", ":test_utils", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:xla_internal_test_main", diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index e8a5efe796a9209307ecfa343b89f66ff2a34e0f..36a706496918ac8c15780473019e2a8d098ffa22 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -2225,6 +2225,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) { ComputeAndCompareR1(&builder, {32, 31, 27, 15, 9, 3, 0}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, ClzS64s) { + XlaBuilder builder(TestName()); + auto a = + builder.ConstantR1({0, 1, 0x80000000, 0x7FFFFFFFF2345678ul, -1}); + builder.Clz(a); + + ComputeAndCompareR1(&builder, {64, 63, 32, 1, 0}, {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) { // a ------ (add) --------- (add) // / / diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index 4e65cf11f3f1a027e1adc5a89930caba28958fea..ca337e78840e77377719636cd4cf33af2578210d 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -37,7 +37,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 34c86e007beea1cbac04641bdbdab62dc567f13e..3a0f51fc66d65c8684bd607b9e8103559cd4d8d4 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -671,7 +671,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("op BINOP_ADD with incompatible shapes")); + HasSubstr("op add with incompatible shapes")); } XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { @@ -684,7 +684,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("op BINOP_ADD with incompatible shapes")); + HasSubstr("op add with incompatible shapes")); } } // namespace diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 6ebbf7191833ef85ee4a48cc96c0a3be38c71228..51b9f0d3e330e73f5d110f0a62f824179d5c7cf7 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*Literal::CreateR0(42.0), *result, - error_spec_); + 
EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR0(42.0), *result, + error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { @@ -62,9 +62,9 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, - error_spec_); + error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { @@ -85,13 +85,13 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), - LiteralView::Create(*result, {0}), error_spec_); + LiteralSlice(*result, {0}), error_spec_)); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), - LiteralView::Create(*result, {1}), error_spec_); + LiteralSlice(*result, {1}), error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { @@ -106,9 +106,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( - *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), *result, - error_spec_); + EXPECT_TRUE( + LiteralTestUtil::Near(*Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + *result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { @@ -125,9 +125,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( - *Literal::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), *result, - error_spec_); + EXPECT_TRUE( + LiteralTestUtil::Near(*Literal::CreateR2({{1.0, 
3.0}, {2.0, 4.0}}), + *result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { @@ -142,10 +142,10 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), - *result, error_spec_); + *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { @@ -166,8 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { Array2D pz({{1, 2}, {1, 2}}); expected.FillWithPZ(pz); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { @@ -196,8 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { } expected.FillWithYX(yx); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { @@ -218,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(r4_array), *result, - error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR4FromArray4D(r4_array), + *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { @@ -238,8 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { Array4D expected(64, 64, 3, 3); expected.Fill(1.0f); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + 
EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { @@ -260,8 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { Array4D expected(3, 3, 2, 2); expected.FillWithYX(to_broadcast); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { @@ -291,8 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index a43ca3d5ca2ba39ba9c16213e985e50bf39c0b7d..5fd33b50c94356839bbed58acd43b7d0286f4a7e 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 41f9a5f66649dd0d697287c5e2af322fc63c1396..bf8ed4d9fb0bc61b86ef0b5872711a122a3d416b 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -178,8 +177,7 @@ void ClientLibraryTestBase::ComputeAndCompareLiteral( error, shape_with_layout)); } -tensorflow::Status -ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( const xla::XlaComputation& computation, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const std::function arguments, const std::function choose; - choose = [&, this](int64 index) -> tensorflow::Status { + std::function choose; + choose = [&, this](int64 index) -> Status { if (index < arguments.size()) { // Try out all layouts for the operand. TF_ASSIGN_OR_RETURN(auto literal, @@ -230,7 +227,7 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( TF_RETURN_IF_ERROR(choose(index + 1)); arguments_with_layout.pop_back(); layout_strings.pop_back(); - return tensorflow::Status::OK(); + return Status::OK(); } std::vector minor_to_major(ShapeUtil::Rank(literal->shape())); @@ -248,7 +245,7 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( layout_strings.pop_back(); } while ( std::next_permutation(minor_to_major.begin(), minor_to_major.end())); - return tensorflow::Status::OK(); + return Status::OK(); } // Every argument has an assigned layout. 
@@ -263,13 +260,13 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( tensorflow::strings::StrAppend(&error_message, str, " "); } verify_output(*actual, error_message); - return tensorflow::Status::OK(); + return Status::OK(); }; return choose(0); } -tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments_passed_in, const Shape* shape_with_layout) { @@ -297,7 +294,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( std::unique_ptr converted_expected; Shape layout_shape; if (use_bfloat16_) { - converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); + converted_expected = Literal::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; @@ -311,7 +308,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } } auto expect_equal = [&](const Literal& actual, const string& error_message) { - LiteralTestUtil::ExpectEqual(*expected_ptr, actual, error_message); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)) << error_message; }; if (execution_options_.debug_options().xla_test_all_output_layouts()) { return ComputeAndCompareLiteralWithAllOutputLayouts( @@ -323,11 +320,11 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - LiteralTestUtil::ExpectEqual(*expected_ptr, *actual); - return tensorflow::Status::OK(); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual)); + return Status::OK(); } -tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, 
tensorflow::gtl::ArraySlice arguments_passed_in, ErrorSpec error, const Shape* shape_with_layout) { @@ -349,7 +346,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( std::unique_ptr converted_expected; Shape layout_shape; if (use_bfloat16_) { - converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); + converted_expected = Literal::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; @@ -363,7 +360,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } } auto expect_near = [&](const Literal& actual, const string& error_message) { - LiteralTestUtil::ExpectNear(*expected_ptr, actual, error, error_message); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error)) + << error_message; }; if (execution_options_.debug_options().xla_test_all_output_layouts()) { return ComputeAndCompareLiteralWithAllOutputLayouts( @@ -375,8 +373,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - LiteralTestUtil::ExpectNear(*expected_ptr, *actual, error); - return tensorflow::Status::OK(); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error)); + return Status::OK(); } void ClientLibraryTestBase::ComputeAndCompareR1U8( @@ -407,7 +405,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(expected, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual)); } void ClientLibraryTestBase::ComputeAndCompareTuple( @@ -419,7 +417,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(expected, *actual, error); + EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error)); } void 
ClientLibraryTestBase::ComputeAndCompare( @@ -431,7 +429,7 @@ void ClientLibraryTestBase::ComputeAndCompare( } std::unique_ptr reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(*reference, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result)); } void ClientLibraryTestBase::ComputeAndCompare( @@ -444,7 +442,7 @@ void ClientLibraryTestBase::ComputeAndCompare( } std::unique_ptr reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*reference, *result, error); + EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error)); } StatusOr, std::unique_ptr>> @@ -562,7 +560,7 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder) { return builder->ConstantLiteral( - use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); + use_bfloat16_ ? *Literal::ConvertF32ToBF16(literal) : literal); } std::unique_ptr @@ -583,7 +581,7 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral( const Literal* param_literal = &literal; std::unique_ptr converted_literal; if (use_bfloat16_) { - converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); + converted_literal = Literal::ConvertF32ToBF16(literal); param_literal = converted_literal.get(); } std::unique_ptr data = diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 16e838e60ffbd7b22878ac21c760ade599f33594..0499fec5898a42affa0e0a712dee10187355c13e 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -188,11 +188,11 @@ class ClientLibraryTestBase : public ::testing::Test { const Shape* shape_with_layout = nullptr); // ComputeAndCompare variant which returns an error status. 
- tensorflow::Status ComputeAndCompareLiteralWithStatus( + Status ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout = nullptr); - tensorflow::Status ComputeAndCompareLiteralWithStatus( + Status ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); @@ -378,12 +378,12 @@ class ClientLibraryTestBase : public ::testing::Test { ExecutionOptions execution_options_; private: - tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts( + Status ComputeAndCompareLiteralWithAllOutputLayouts( const xla::XlaComputation& computation, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const std::function& verify_output); - tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts( + Status ComputeAndCompareLiteralWithAllInputLayouts( const xla::XlaComputation& computation, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const std::function ClientLibraryTestBase::CreateR0Parameter( XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR0(value); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + literal = Literal::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -555,7 +555,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( const string& name, XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR1(values); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + literal = Literal::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -569,7 +569,7 @@ std::unique_ptr 
ClientLibraryTestBase::CreateR2Parameter( const string& name, XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR2FromArray2D(array_2d); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + literal = Literal::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -583,7 +583,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const string& name, XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR3FromArray3D(array_3d); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + literal = Literal::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 0b425b93bb144e395baef2bcf074fd6e7991630b..08671cf62445826649b5c97003f998ae98a59d97 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -62,9 +62,9 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { TF_ASSERT_OK_AND_ASSIGN( auto computed, client_->Transfer(*data, &expected_literal->shape())); - LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(), - computed->shape()); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed); + ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( + expected_literal->shape(), computed->shape())); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } } @@ -91,9 +91,9 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { auto result, client_->ExecuteAndTransfer(computation, {}, &execution_options)); LiteralTestUtil::ExpectR2Equal({{1, 2}, {3, 4}}, - LiteralView::Create(*result, {0})); + LiteralSlice(*result, {0})); LiteralTestUtil::ExpectR2Equal({{10, 20}, {30, 40}}, - 
LiteralView::Create(*result, {1})); + LiteralSlice(*result, {1})); EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); @@ -142,7 +142,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { auto result_literal, client_->Transfer(*results[0], &expected_result->shape())); - LiteralTestUtil::ExpectEqual(*expected_result, *result_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index ecce599a8a3bd588c11d6bb9ba461b5a917197db..50a006964869b3e5dce431d441f7cd81af9df910 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" @@ -50,8 +49,8 @@ class CompilationCacheTest : public ClientLibraryTestBase { /*execution_options=*/&execution_options_, &execution_profile) .ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*Literal::CreateR0(expected_result), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR0(expected_result), *result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -67,8 +66,8 @@ class CompilationCacheTest : public ClientLibraryTestBase { .ConsumeValueOrDie(); std::unique_ptr result = client_->Transfer(*data_handle).ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*Literal::CreateR2(expected_result), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR2(expected_result), *result, error_spec_)); EXPECT_EQ(expect_cache_hit, 
execution_profile.compilation_cache_hit()); } diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index bf4b8fb0bcf229b4e8649b3920dcba1ae0579831..ba22530f1cfee56337f862c25122d399dbf0f1e4 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -208,7 +208,7 @@ TEST_F(ComputeConstantTest, NonScalarAdd) { ComputeConstantLiteral(client, computation, &b)); std::unique_ptr expected_literal = Literal::CreateR1({4, 6}); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } @@ -222,7 +222,7 @@ TEST_F(ComputeConstantTest, IntegerDivide) { TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); std::unique_ptr expected_literal = Literal::CreateR0(5); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } @@ -244,9 +244,9 @@ XLA_TEST_F(ComputeConstantTest, Layout) { std::unique_ptr expected_literal = Literal::CreateR2WithLayout({{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout)); - LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(), - computed->shape()); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed); + ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( + expected_literal->shape(), computed->shape())); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } } diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 4743673561a665ca8670a56bf15d85a74073e472..916ffadbc798ec0dd016f45b0bc4c36233455ee7 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -21,13 +21,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -169,9 +167,9 @@ TEST_F(ConstantsTest, DISABLED_TupleConstant) { ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); LiteralTestUtil::ExpectR2Near( - {{1.0}, {2.0}}, LiteralView::Create(*result, {0}), error_spec_); + {{1.0}, {2.0}}, LiteralSlice(*result, {0}), error_spec_); LiteralTestUtil::ExpectR1Near( - {2.0, 42.0}, LiteralView::Create(*result, {1}), error_spec_); + {2.0, 42.0}, LiteralSlice(*result, {1}), error_spec_); } } // namespace diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 4ef0a77884c90b9fe32f96d3361fa3d80bde623b..722d882471a41a75c1e5e60f8c1a151b76c7e004 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -249,10 +249,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { -1.99f, -2.0f, -2.01f, - 0x1.FFFFFEp+62F, - 0x1.FFFFFCp+62F, - -0x1.FFFFFEp+62F, - -0x1.FFFFFCp+62F}; + 9223371487098961920.f, + 9223370937343148032.f, + -9223371487098961920.f, + -9223370937343148032.f}; std::unique_ptr arg_literal = Literal::CreateR1({arg}); auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 
947959beb144e1509a77ad2f94b8493de46ba6f2..346bb3a3996ee5bf662b0f74dd0c2096efbf5295 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -47,9 +47,9 @@ class ConvolutionTest : public ClientLibraryTestBase { #if XLA_TEST_BACKEND_GPU // XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial // convolution. So relax the absolute error threshold. - ErrorSpec error_spec_ = ErrorSpec(1e-2); + ErrorSpec error_spec_ = ErrorSpec(1e-2, 1e-4); #else - ErrorSpec error_spec_ = ErrorSpec(1e-4); + ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-4); #endif }; diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index 50d6e25d868c4964ff35023b43a3734ed115bbb8..fea850dc135e33fe098aa755c6fdd93319cd2837 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 155fbacf58d81cff27939c142c8f30158cef4e00..2b3390ca98cb2922410d451c06811aa9d4ff8c0b 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -49,7 +49,7 @@ class CopyOpTest : public HloTestBase { module->AddEntryComputation(std::move(computation)); std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectEqual(literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result)); } void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3); @@ -253,7 +253,7 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) { auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape) .ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(*empty, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b151187c4b8f01c5b46ccadf27d2e22a7c902e98 --- /dev/null +++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc @@ -0,0 +1,103 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +class TrivialCrossReplicaSumTest : public HloTestBase {}; + +// Currently the CPU and GPU backends only support CrossReplicaSum with one +// replica. But we can at least check this. 
+ +XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { + const char* module_str = R"( + HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + + ENTRY test_computation { + p = f32[3] parameter(0) + ROOT crs = f32[3] cross-replica-sum(p), to_apply=add + })"; + auto module = + ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto literal = Literal::CreateR1({1, 2, 3}); + EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()})); +} + +XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { + const char* module_str = R"( + HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + + ENTRY test_computation { + p0 = f32[3] parameter(0) + p1 = f32[2] parameter(1) + ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add + })"; + auto module = + ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto literal0 = Literal::CreateR1({1, 2, 3}); + auto literal1 = Literal::CreateR1({10, 20}); + EXPECT_EQ( + *Literal::MakeTuple({literal0.get(), literal1.get()}), + *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()})); +} + +// On the GPU backend, constants get special handling. Someone might pass a +// constant to CRS to e.g. count the number of replicas -- we need to make sure +// it works. 
+XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) { + const char* module_str = R"( + HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + + ENTRY test_computation { + p0 = f32[3] parameter(0) + p1 = f32[2] constant({10, 20}) + ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add + })"; + auto module = + ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto literal0 = Literal::CreateR1({1, 2, 3}); + auto literal1 = Literal::CreateR1({10, 20}); + EXPECT_EQ(*Literal::MakeTuple({literal0.get(), literal1.get()}), + *ExecuteAndTransfer(std::move(module), {literal0.get()})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index c76e5aabf4b8a3463b2971654d0a6cf0dd594626..bfe688e20d182d581c3e3b545ac2289413deef7c 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index d0ada2474830390e50a90c4c41aa42166d6e8ea5..12789fe66530fe03eb33316eda652336f29971ab 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -16,7 +16,6 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index b236cf00a8053af296c8b5f1e8e5db937d5e6fd6..0fd846cef8095a857dd7b2c12d8afdf409e2bd66 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -61,7 +61,7 @@ using TypesF16F32F64CF64 = ::testing::Types; #endif // Check that we can safely pass an input tuple's elements to a dot operation. -TEST_F(DotOperationTest, DotOfInputTupleElem) { +XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { XlaBuilder builder(TestName()); XlaOp param; diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index bfb83faf5222b8ca5ceceebf7f2f976ec803245e..49f3a10d227f2f9edfe76405ba13498fe822f8d8 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -53,9 +53,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void TestR1Wrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. - RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {6, 7, 0, 1}); + void TestR1OOB() { + // Slice at dimension boundaries, but with out of bounds indices. + RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {4, 5, 6, 7}); } template @@ -78,10 +78,10 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void TestR2Wrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. + void TestR2OOB() { + // Slice at dimension boundaries, but with out of bounds indices. 
RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {3, 3}, - {{5, 6, 4}, {8, 9, 7}, {2, 3, 1}}); + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); } template @@ -106,11 +106,11 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void TestR3Wrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. + void TestR3OOB() { + // Slice at dimension boundaries, but with out of bounds indices. RunR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {0, 2, 1}, - {2, 1, 2}, {{{6, 5}}, {{12, 11}}}); + {2, 1, 2}, {{{5, 6}}, {{11, 12}}}); } template @@ -199,19 +199,19 @@ class DynamicSliceTest : public ClientLibraryTestBase { XLA_TEST_F(DynamicSliceTest, Int32R1BF16) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1(); } -XLA_TEST_F(DynamicSliceTest, Int32R1Wrap) { TestR1Wrap(); } +XLA_TEST_F(DynamicSliceTest, Int32R1OOB) { TestR1OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1(); } XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R2BF16) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2(); } -XLA_TEST_F(DynamicSliceTest, Int32R2Wrap) { TestR2Wrap(); } +XLA_TEST_F(DynamicSliceTest, Int32R2OOB) { TestR2OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R3BF16) { TestR3(); } XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3(); } -XLA_TEST_F(DynamicSliceTest, Int32R3Wrap) { TestR3Wrap(); } +XLA_TEST_F(DynamicSliceTest, Int32R3OOB) { TestR3OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3(); } XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3(); } @@ -332,17 +332,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } template - void TestWrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. + void TestOOB() { + // Slice at dimension boundaries, but with out of bounds indices.
RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {6}, - {10, 1, 2, 3, 4, 5, 8, 9}); + {0, 1, 2, 3, 4, 8, 9, 10}); // R2 Shape: [3, 3] RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 2}, - {{1, 2, 3}, {4, 5, 6}, {11, 8, 10}}); + {{1, 2, 3}, {4, 5, 6}, {7, 10, 11}}); // R3 Shape: [2, 3, 2] RunR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}}, - {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 15}, {9, 10}, {11, 13}}}); + {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 13}, {11, 15}}}); } template @@ -476,20 +476,19 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { Array3D input_values(kSeq, kBatch, kDim); Array3D update_values(size, kBatch, kDim); Array3D expected_values(kSeq, kBatch, kDim); + index = std::min(std::max(0, index), kSeq - size); input_values.FillIota(static_cast(0)); T value = static_cast(10); update_values.FillIota(static_cast(value)); // TODO(b/34128753) Expected values may vary depending on backend when - // the update wraps. According to documentation, the results are technically - // implementation specific where the update is out of bounds, and hence - // we don't really know what to pass into ComputeAndCompareR3. + // the indices are out of bounds. 
expected_values.FillIota(static_cast(0)); for (int i = 0; i < size; i++) { for (int j = 0; j < kBatch; j++) { for (int k = 0; k < kDim; k++) { - expected_values((index + i) % kSeq, j, k) = value++; + expected_values(index + i, j, k) = value++; } } } @@ -547,12 +546,10 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int32WrapBF16) { - TestWrap(); -} -XLA_TEST_F(DynamicUpdateSliceTest, Int32Wrap) { TestWrap(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int64Wrap) { TestWrap(); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt64Wrap) { TestWrap(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32OOBBF16) { TestOOB(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32OOB) { TestOOB(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int64OOB) { TestOOB(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt64OOB) { TestOOB(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) { // Slice at dimension start. @@ -615,37 +612,37 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R3Pred) { // Tests for simple R3 case where the update is contiguous (i.e. the minor // two dimensions are not sliced). XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) { - // Single element, no wrap. + // Single element, index in-bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElementBF16) { - // Single element, no wrap. + // Single element, index in-bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) { - // Multiple element, no wrap. + // Multiple elements, index in-bounds.
std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElementsBF16) { - // Multiple element, no wrap. + // Multiple elements, index in-bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); } -XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrapping) { - // Multiple element, wrapping. +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOB) { + // Multiple elements, index out of bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); } -XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrappingBF16) { - // Multiple element, wrapping. +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOBBF16) { + // Multiple elements, index out of bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); } diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index b947f8208a5fa3f5a396ebc7a234afbf7ac3d900..e6f79b5ac55dddfbb213a36cadbee53bc9443d9d 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -118,9 +118,9 @@ class FusionTest : public HloTestBase { auto expected = Literal::CreateR2FromArray2D(answer_data); auto actual = ExecuteAndTransfer(std::move(hlo_module), {}); if (primitive_util::IsFloatingPointType(prim_type)) { - LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4))); } else { - LiteralTestUtil::ExpectEqual(*expected, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); } } @@ -221,9 +221,9 @@ XLA_TEST_F(FusionTest, Test) { const4, reshape3, add2, const1, const0}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectNear(*Literal::CreateR2({{0.5}, {2.72}}), - *ExecuteAndTransfer(std::move(hlo_module), 
{}), - ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR2({{0.5}, {2.72}}), + *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } // Test whether we emit appropriate code for parameters of fusion instructions. @@ -247,9 +247,9 @@ XLA_TEST_F(FusionTest, Parameter) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectNear(*Literal::CreateR2({{-1.0, 0.0, 1.0}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), - ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR2({{-1.0, 0.0, 1.0}}), + *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } XLA_TEST_F(FusionTest, RandomizedParallelPartition) { @@ -307,9 +307,9 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)); + *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } XLA_TEST_F(FusionTest, ReshapeToScalar) { @@ -322,8 +322,9 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(5), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(5), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { @@ -336,9 +337,9 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + 
EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { @@ -351,9 +352,9 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_1by1by1_) { @@ -366,8 +367,9 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(7), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(7), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape__1by1by1) { @@ -380,8 +382,9 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR3({{{7}}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR3({{{7}}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape__) { @@ -394,8 +397,9 @@ XLA_TEST_F(FusionTest, Reshape__) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(7), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + 
LiteralTestUtil::Equal(*Literal::CreateR0(7), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { @@ -408,9 +412,9 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Transpose_2by3) { @@ -423,9 +427,9 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 4}, {2, 5}, {3, 6}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Transpose_3by3) { @@ -438,9 +442,9 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reverse) { @@ -454,8 +458,9 @@ XLA_TEST_F(FusionTest, Reverse) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({3, 2, 1}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({3, 2, 1}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReverseNegate) 
{ @@ -471,8 +476,9 @@ XLA_TEST_F(FusionTest, ReverseNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reverse1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-3, -2, -1}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({-3, -2, -1}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, BroadcastNegate) { @@ -488,8 +494,9 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, broadcast1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-1, -1}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({-1, -1}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, SliceNegate) { @@ -505,8 +512,9 @@ XLA_TEST_F(FusionTest, SliceNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, slice1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-1, -3}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({-1, -3}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DynamicSliceNegate) { @@ -526,8 +534,9 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { /*instructions_to_fuse=*/{negate3, dynamic_slice2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-2, -3}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({-2, -3}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReshapeNegate) { @@ -543,8 +552,9 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR2({{-1, 
-2}, {-3, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{-1, -2}, {-3, -4}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } // TODO(b/64070202): Investigate failure. @@ -561,8 +571,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_GPU(TransposeNegate)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR2({{-1, -3}, {-2, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{-1, -3}, {-2, -4}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } std::unique_ptr MakeReduceTestComputation() { @@ -591,8 +602,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(15), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(15), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { @@ -612,8 +624,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(-15), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(-15), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { @@ -661,9 +674,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{462, 2145}, {24871, 62491}}), - 
*ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } // When a constant (or other op) which has multiple users is imported @@ -697,8 +710,9 @@ XLA_TEST_F(FusionTest, SharedConstant) { // fused instruction contains the constant(2), the parameter, and 4 adds EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({8}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({8}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D(HloOpcode::kAdd); } diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 130456e61ca8a217e903d2ddecc487f29a098ce1..143ffbdeb409d91ab6d46d386aa5ff98ebc4ae10 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" // NB! TODO(b/74360564): These tests do not test out of bounds behavior since // that hasn't been specced yet. 
@@ -41,7 +41,7 @@ class GatherOperationTest : public HloTestBase { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text, config)); + ParseHloString(hlo_text, config)); EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt)); } }; @@ -629,8 +629,8 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { client_->ExecuteParallel(computation_instances)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, client_->Transfer(*(result_data[0]))); - LiteralTestUtil::ExpectEqual( - *result_literal, *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result_literal, *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}))); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 12598579c7032e954c4a4875ab8e6475b112f5ae..242cc5db11ff2bdf69209df7537216573d8afbf3 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -23,11 +23,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -94,18 +94,14 @@ HloTestBase::HloTestBase(se::Platform* test_platform, /* static */ std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) { - HloModuleConfig config; - auto debug_options = HloTestBase::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_max_kernel_unroll_factor(1); - config.set_debug_options(debug_options); - - return MakeUnique<HloModule>(name, VersionedComputationHandle(), config); + return MakeUnique<HloModule>(name, GetModuleConfigForTest()); } /*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() { auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. debug_options.add_xla_disable_hlo_passes("constant_folding"); + debug_options.set_xla_gpu_max_kernel_unroll_factor(1); return debug_options; } diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 9539ae06801628baedaea69024b7760ebefa6e3a..249da87f489324ed9d377cc46a15cef5a9e74192 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -66,6 +66,15 @@ namespace xla { // // For a more detailed example, see "../tests/sample_text_test.cc". class HloTestBase : public ::testing::Test { + public: + // Creates a new HLO module for a test.
The module created will have + // TestName() for its name; it will also automatically populate its debug + // options from command-line flags. If you want a fresh HloModule object and + // then add HloComputations to it, it's recommended to use this method in your + // tests. + static std::unique_ptr<HloModule> CreateNewModule( + const string& name = TestName()); + protected: // This uses the interpreter backend as the reference backend and // automatically finds another supported backend as the test backend. If the @@ -80,19 +89,18 @@ class HloTestBase : public ::testing::Test { ~HloTestBase() override {} - // Creates a new HLO module for a test. The module created will have - // TestName() for its name; it will also automatically populate its debug - // options from command-line flags. If you want a fresh HloModule object and - // then add HloComputations to it, it's recommended to use this method in your - // tests. - static std::unique_ptr<HloModule> CreateNewModule( - const string& name = TestName()); - // Populates debug options from command-line flags and adjusts the options for // testing. It is recommended to use this when you need to pass in // DebugOptions, e.g. when creating a module from a string or a file. static DebugOptions GetDebugOptionsForTest(); + // Gets an HloModuleConfig with options appropriate for tests. + static HloModuleConfig GetModuleConfigForTest() { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + return config; + } + // Executes the given module and return the result as a Literal. StatusOr<std::unique_ptr<Literal>> Execute( std::unique_ptr<HloModule> module, diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index da4cf4ae0c31bc194cd2ec9b845df36afbde69b0..22c664d1426c598dbb695ff1b66ce009b0a19c00 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -15,10 +15,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -41,14 +41,17 @@ void HloVerifiedTestBase::TearDown() { << "TearDown called more than once; it should be called exactly once."; tear_down_called_ = true; if (module_) { - VerifyModule(); + VerifyModule(module_.get()); + } + for (int i = 0; i < modules_.size(); ++i) { + VerifyModule(modules_.at(i).get()); } HloTestBase::TearDown(); } -void HloVerifiedTestBase::VerifyModule() { - HloVerifier verifier; - xla::StatusOr<bool> mutated = verifier.Run(module_.get()); +void HloVerifiedTestBase::VerifyModule(HloModule* module) { + HloVerifier verifier(/*allow_mixed_precision=*/true); + xla::StatusOr<bool> mutated = verifier.Run(module); if (!mutated.ok()) { ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); } else { @@ -59,15 +62,20 @@ void HloVerifiedTestBase::VerifyModule() { HloModule& HloVerifiedTestBase::module() { if (!module_) { - module_ = CreateNewModule(); + module_ = HloTestBase::CreateNewModule(); } return *module_; } +HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { + modules_.emplace_back(HloTestBase::CreateNewModule()); + return modules_.back().get(); +} + void HloVerifiedTestBase::ParseAndVerifyModule( tensorflow::StringPiece hlo_text) { CHECK(!module_) << "Called ParseModule when test already has a module."; - TF_ASSERT_OK_AND_ASSIGN(module_, tools::Parse(hlo_text)); - VerifyModule(); + TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text)); + VerifyModule(module_.get()); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index e5bb14a8839acbdef8fd2b79bb0f574c46ea3d40..5b59cc77f61b05092d3afb331e73932c9edc5840 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -52,11 +52,23 @@ class HloVerifiedTestBase : public HloTestBase { shape_verifier_ = std::move(shape_verifier); } + // Creates a new module for a test, and stores it in modules_ so it can be + // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent + // creation of unverified modules. + HloModule* CreateNewModule(const string& name = TestName()); + + // It is confusing to store modules created by module() and CreateNewModule() + // in different fields, but it allows us to migrate tests to + // HloVerifiedTestBase more easily, so it's a win because we can verify more + // modules. See b/80488902. private: - std::unique_ptr<HloModule> module_; // Lazily populated. Access via module(). + // Lazily populated. Access via module(). + std::unique_ptr<HloModule> module_; + // Populated by calls to CreateNewModule. + std::vector<std::unique_ptr<HloModule>> modules_; std::unique_ptr<ShapeVerifier> shape_verifier_; bool tear_down_called_ = false; - void VerifyModule(); + static void VerifyModule(HloModule* module); }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index c28f79ae386670ca80d603a42f6629dfd30e0bc9..cde1dcd9cd10c86107f495a92be42b57bf6a085b 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -15,978 +15,93 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include -#include -#include - -#include "tensorflow/compiler/xla/index_util.h" -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/compiler/xla/literal_comparison.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/types.h" namespace xla { -using ::tensorflow::strings::Appendf; -using ::tensorflow::strings::Printf; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - -/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes( - const Shape& expected, const Shape& actual) { - if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) { - return ::testing::AssertionFailure() - << "tupleness-mismatch! 
want: " << ShapeUtil::HumanString(expected) - << " got: " << ShapeUtil::HumanString(actual); - } - if (ShapeUtil::IsTuple(expected)) { - if (ShapeUtil::TupleElementCount(expected) != - ShapeUtil::TupleElementCount(actual)) { - return ::testing::AssertionFailure() - << "want tuple element count: " - << ShapeUtil::TupleElementCount(expected) - << " got tuple element count: " - << ShapeUtil::TupleElementCount(actual); - } - for (int i = 0; i < expected.tuple_shapes_size(); ++i) { - ::testing::AssertionResult result = - EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)) - << "mismatch in tuple index " << i; - if (!result) { - return result; - } - } - } else { - if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { - return ::testing::AssertionFailure() - << "want rank of: " << ShapeUtil::HumanString(expected) - << " got rank of: " << ShapeUtil::HumanString(actual); - } - if (expected.element_type() != actual.element_type()) { - return ::testing::AssertionFailure() - << PrimitiveType_Name(expected.element_type()) << " vs " - << PrimitiveType_Name(actual.element_type()); - } - if (expected.dimensions_size() != actual.dimensions_size()) { - return ::testing::AssertionFailure() - << "want dimensions_size " << expected.dimensions_size() - << " got dimensions_size " << actual.dimensions_size(); - } - for (int i = 0; i < expected.dimensions_size(); ++i) { - if (expected.dimensions(i) != actual.dimensions(i)) { - return ::testing::AssertionFailure() - << "mismatch in dimension #" << i - << " expected: " << ShapeUtil::HumanString(expected) - << " actual: " << ShapeUtil::HumanString(actual); - } - } - } - return ::testing::AssertionSuccess(); -} - -/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected, - const Shape& actual) { - ASSERT_TRUE(EqualShapes(expected, actual)); -} - -/* static */ void LiteralTestUtil::AssertEqualShapesAndLayouts( - const Shape& expected, const Shape& actual) { - ASSERT_EQ(expected.ShortDebugString(), 
actual.ShortDebugString()); -} - -namespace { - -// Return a literal with all arrays of type FromNativeT converted to type -// ToNativeT in the given literal. -template -std::unique_ptr ConvertType(const Literal& literal) { - // First construct shape of the result. - Shape result_shape(literal.shape()); - ShapeUtil::ForEachMutableSubshape( - &result_shape, [](Shape* subshape, const ShapeIndex&) { - if (subshape->element_type() == - primitive_util::NativeToPrimitiveType()) { - subshape->set_element_type( - primitive_util::NativeToPrimitiveType()); - } - }); - auto result = MakeUnique(result_shape); - - // Then copy over the data from 'literal' converting FromNativeT values to - // ToNativeT values as necessary. - ShapeUtil::ForEachSubshape( - literal.shape(), - [&](const Shape& subshape, const ShapeIndex& shape_index) { - if (ShapeUtil::IsArray(subshape)) { - if (subshape.element_type() == - primitive_util::NativeToPrimitiveType()) { - auto src = literal.data(shape_index); - auto dest = result->data(shape_index); - for (int64 i = 0; i < src.size(); ++i) { - dest[i] = static_cast(src[i]); - } - } else { - TF_CHECK_OK(result->CopyFrom(literal, - /*dest_shape_index=*/shape_index, - /*src_shape_index=*/shape_index)); - } - } - }); - return result; -} - -} // namespace - -/* static */ std::unique_ptr LiteralTestUtil::ConvertBF16ToF32( - const Literal& literal) { - return ConvertType(literal); -} - -/* static */ std::unique_ptr LiteralTestUtil::ConvertF32ToBF16( - const Literal& literal) { - return ConvertType(literal); -} - namespace { -string Hostname() { - char hostname[1024]; - gethostname(hostname, sizeof hostname); - hostname[sizeof hostname - 1] = 0; - return string(hostname); -} - -// Helper function for comparing a floating point type, FloatT, bitwise equal -// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT -// -- on miscompare, a nice error message is given in the AssertionFailure. 
-template -::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { - auto ulhs = tensorflow::bit_cast(lhs); - auto urhs = tensorflow::bit_cast(rhs); - auto lhs_double = static_cast(lhs); - auto rhs_double = static_cast(rhs); - if (ulhs != urhs) { - return ::testing::AssertionFailure() << Printf( - "floating values are not bitwise-equal; and equality testing " - "was requested: %s=%g=%a vs %s=%g=%a", - StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, - lhs_double, StrCat(tensorflow::strings::Hex(urhs)).c_str(), - rhs_double, rhs_double); - } - return ::testing::AssertionSuccess(); -} - -// Templated comparator that specializes for float equality comparison with the -// bitwise helper above (this is the un-specialized fallback, to just use the -// default gunit implementation). -template -::testing::AssertionResult CompareEqual(NativeT lhs, NativeT rhs) { - if (lhs == rhs) { +// Writes the given literal to a file in the test temporary directory. +void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) { + auto get_hostname = [] { + char hostname[1024]; + gethostname(hostname, sizeof hostname); + hostname[sizeof hostname - 1] = 0; + return string(hostname); + }; + int64 now_usec = tensorflow::Env::Default()->NowMicros(); + string filename = tensorflow::io::JoinPath( + tensorflow::testing::TmpDir(), + tensorflow::strings::Printf("tempfile-%s-%llx-%s", get_hostname().c_str(), + now_usec, name.c_str())); + TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), filename, + literal.ToProto())); + LOG(ERROR) << "wrote to " << name << " file: " << filename; +} + +// Callback helper that dumps literals to temporary files in the event of a +// miscomparison. 
+void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, + const LiteralSlice& mismatches) { + LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) << " " + << literal_comparison::ToStringTruncated(expected); + LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) << " " + << literal_comparison::ToStringTruncated(actual); + LOG(INFO) << "Dumping literals to temp files..."; + WriteLiteralToTempFile(expected, "expected"); + WriteLiteralToTempFile(actual, "actual"); + WriteLiteralToTempFile(mismatches, "mismatches"); +} + +::testing::AssertionResult StatusToAssertion(const Status& s) { + if (s.ok()) { return ::testing::AssertionSuccess(); } - ::testing::Message msg; - msg << "Expected equality of these values:"; - msg << "\n " << lhs; - msg << "\n " << rhs; - - return ::testing::AssertionFailure() << msg; -} - -// Specializations for floating types that do bitwise comparisons when equality -// comparison is requested. -template <> -::testing::AssertionResult CompareEqual(bfloat16 lhs, bfloat16 rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); -} -template <> -::testing::AssertionResult CompareEqual(Eigen::half lhs, - Eigen::half rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); -} -template <> -::testing::AssertionResult CompareEqual(float lhs, float rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); -} -template <> -::testing::AssertionResult CompareEqual(double lhs, double rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); -} -template <> -::testing::AssertionResult CompareEqual(complex64 lhs, - complex64 rhs) { - auto res = CompareEqual(lhs.real(), rhs.real()); - if (!res) { - return res; - } - return CompareEqual(lhs.imag(), rhs.imag()); -} - -// A recursive function which iterates through every index of expected and -// actual literal and compares their values elementwise. Returns true if all -// elements are equal. 
-template -bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, - tensorflow::gtl::MutableArraySlice multi_index, - int64 dimension) { - if (dimension == expected.shape().dimensions_size()) { - NativeT expected_value = expected.Get(multi_index); - NativeT actual_value = actual.Get(multi_index); - ::testing::AssertionResult result = - CompareEqual(expected_value, actual_value); - return result; // Defines implicit coersion to bool. - } - - bool all_match = true; - for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { - multi_index[dimension] = i; - all_match = all_match && ExpectLiteralsEqual( - expected, actual, multi_index, dimension + 1); - } - return all_match; + return ::testing::AssertionFailure() << s.error_message(); } } // namespace -/* static */ void LiteralTestUtil::ExpectEqual(const Literal& expected, - const Literal& actual, - const string& message) { - EXPECT_TRUE(Equal(expected, actual)) - << "expected:\n" - << expected.ToString() << "\n\tvs actual:\n" - << actual.ToString() - << (message.empty() ? 
"" : StrCat("\nmessage: ", message)); -} - -/* static */ void LiteralTestUtil::ExpectNotEqual(const Literal& expected, - const Literal& actual) { - EXPECT_FALSE(Equal(expected, actual)); -} - -/* static */ ::testing::AssertionResult LiteralTestUtil::Equal( - const Literal& expected, const Literal& actual) { - VLOG(1) << "expected:"; - XLA_VLOG_LINES(1, expected.ToString()); - VLOG(1) << "actual:"; - XLA_VLOG_LINES(1, actual.ToString()); - - AssertEqualShapes(expected.shape(), actual.shape()); - std::vector multi_index(expected.shape().dimensions_size(), 0); - bool match = false; - switch (expected.shape().element_type()) { - case PRED: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case U8: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case S32: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case S64: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case U32: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case U64: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case BF16: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case F16: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case F32: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case F64: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case C64: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case TUPLE: { - bool tuple_match = true; - for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { - SCOPED_TRACE(StrCat("Tuple index ", i, " in ", - ShapeUtil::HumanString(expected.shape()))); - - // Create LiteralViews of the expected and actual elements. - auto result = Equal(LiteralView::Create(expected, {i}), - LiteralView::Create(actual, {i})); - tuple_match = tuple_match ? 
!!result : false; - } - match = tuple_match; - break; - } - default: - LOG(FATAL) - << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " - << PrimitiveType_Name(expected.shape().element_type()); - } - ::testing::AssertionResult result = ::testing::AssertionSuccess(); - if (!match) { - result = ::testing::AssertionFailure() - << "expected: " << expected.ToString() - << "\nactual: " << actual.ToString(); - VLOG(1) << result.message(); - } - return result; -} - -namespace { - -// Gets the total element count. For tuples, this is not the count of tuple -// elements, but the sum of elements of each tuple element. -int64 RecursiveElementCount(const Shape& shape) { - if (ShapeUtil::IsTuple(shape)) { - const int64 tuple_elements = ShapeUtil::TupleElementCount(shape); - int64 total = 0; - for (int64 i = 0; i < tuple_elements; ++i) { - total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); - } - return total; - } else { - return ShapeUtil::ElementsIn(shape); - } -} - -// Calling ToString on a literal with over 100 million elements takes around -// 3 minutes. The utility of printing a literal with >1000 elements is -// questionable, especially when writing the Literal proto to disk is orders -// of magnitude faster. -string TruncateHugeLiteral(const Literal& literal) { - return RecursiveElementCount(literal.shape()) < 1000 - ? literal.ToString() - : "[TRUNCATED, Literal with more than 1000 values]"; +/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes( + const Shape& expected, const Shape& actual) { + return StatusToAssertion(literal_comparison::EqualShapes(expected, actual)); } -// Returns whether the actual and expected values are mismatched with respect to -// nans. 'relaxed_nans' is interpreted as in xla::ErrorSpec. 
-template -bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) { - if (relaxed_nans) { - return !std::isnan(expected) && std::isnan(actual); - } else { - return std::isnan(expected) != std::isnan(actual); +/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapesAndLayouts( + const Shape& expected, const Shape& actual) { + if (expected.ShortDebugString() != actual.ShortDebugString()) { + return ::testing::AssertionFailure() + << "want: " << expected.ShortDebugString() + << " got: " << actual.ShortDebugString(); } + return ::testing::AssertionSuccess(); } -template <> -bool NanMismatch(complex64 expected, complex64 actual, - bool relaxed_nans) { - return NanMismatch(expected.real(), actual.real(), relaxed_nans) || - NanMismatch(expected.imag(), actual.imag(), relaxed_nans); -} - -template <> -bool NanMismatch(half expected, half actual, bool relaxed_nans) { - return NanMismatch(static_cast(expected), - static_cast(actual), relaxed_nans); -} - -// Converts the given floating-point value to a string. -template -string FpValueToString(NativeT value) { - return Printf("%8.4g", static_cast(value)); -} - -template <> -string FpValueToString(complex64 value) { - return Printf("%8.4g + %8.4fi", value.real(), value.imag()); -} - -// Returns the absolute value of the given floating point value. This function -// is used instead of std::abs directly in order to allow type-dependent -// implementations for NearComparator. -template -float FpAbsoluteValue(NativeT value) { - return std::abs(value); -} - -template <> -float FpAbsoluteValue(bfloat16 value) { - return FpAbsoluteValue(static_cast(value)); -} - -template <> -float FpAbsoluteValue(half value) { - return FpAbsoluteValue(static_cast(value)); -} - -// Helper class for comparing floating-point literals within an error bound. -template -class NearComparator { - public: - // Compares the two array literals elementwise and returns an assertion - // result. 
The assertion result is successful if all actual and expected - // elements are within the given error bound. In case of error, the assertion - // result contains a detailed error message in case of failure. - static ::testing::AssertionResult Compare(const Literal& expected, - const Literal& actual, - ErrorSpec error, - bool detailed_message) { - NearComparator comparator(expected, actual, error, - detailed_message); - return comparator.Run(); - } - - private: - // Data structure encapsulating metadata about a single element mismatch. - struct Mismatch { - NativeT actual; - NativeT expected; - float rel_error; - float abs_error; - - // The linear index of the failure within the shape. This linear index is - // from the 'actual' literal. - int64 linear_index; - - bool operator<(const Mismatch& other) const { - return rel_error < other.rel_error; - } - - string ToString(const Shape& shape) const { - return Printf( - "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g", - FpValueToString(actual).c_str(), FpValueToString(expected).c_str(), - LiteralTestUtil::MultiIndexAsString( - IndexUtil::LinearIndexToMultidimensionalIndex(shape, - linear_index)) - .c_str(), - rel_error, abs_error); - } - }; - - explicit NearComparator(const Literal& expected, const Literal& actual, - ErrorSpec error, bool detailed_message) - : expected_(expected), - actual_(actual), - error_(error), - detailed_message_(detailed_message), - abs_value_buckets_(kAbsValueBucketBounds.size() - 1, {0, 0}), - abs_error_buckets_(kErrorBucketBounds.size(), 0), - rel_error_buckets_(kErrorBucketBounds.size(), 0) {} - - // Runs the comparison between expected and actual literals. 
- ::testing::AssertionResult Run() { - VLOG(1) << "expected:"; - XLA_VLOG_LINES(1, TruncateHugeLiteral(expected_)); - VLOG(1) << "actual:"; - XLA_VLOG_LINES(1, TruncateHugeLiteral(actual_)); - - // If the shapes mismatch, we simply fail the expectation instead of - // printing out data, as it's a type error rather than a value error. - ::testing::AssertionResult equal_shapes = - LiteralTestUtil::EqualShapes(expected_.shape(), actual_.shape()); - if (!equal_shapes) { - return equal_shapes; - } - if (!ShapeUtil::IsArray(expected_.shape())) { - return ::testing::AssertionFailure() << "Expected array shape"; - } - - mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED)); - mismatches_.PopulateWithValue(false); - - CompareLiterals(); - - if (num_mismatches_ == 0) { - return ::testing::AssertionSuccess(); - } else if (!VLOG_IS_ON(1)) { - LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected_.shape()) - << " " << TruncateHugeLiteral(expected_); - LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual_.shape()) - << " " << TruncateHugeLiteral(actual_); - LOG(INFO) << "Dumping literals to temp files..."; - WriteLiteralToTempFile(expected_, "expected"); - WriteLiteralToTempFile(actual_, "actual"); - WriteLiteralToTempFile(mismatches_, "mismatches"); - } - return ::testing::AssertionFailure() << ErrorMessage(); - } - - // Insert the given absolute value into the absolute value bucket vector. The - // bounds of the buckets are given by kAbsValueBucketBounds. - void UpdateAbsValueBucket(NativeT value, bool is_mismatch) { - // Adjust the bucket containing the absolute values of the 'actual' - // elements. 
- const float abs_value = FpAbsoluteValue(value); - for (int i = 0; i < abs_value_buckets_.size(); ++i) { - if (i == abs_value_buckets_.size() - 1 || - (abs_value >= kAbsValueBucketBounds[i] && - abs_value < kAbsValueBucketBounds[i + 1])) { - // The first value of the pair is the count of elements in the bucket, - // the second is the count of mismatches in the bucket. - abs_value_buckets_[i].first++; - if (is_mismatch) { - abs_value_buckets_[i].second++; - } - return; - } - } - } - - // Insert the given error into the given error bucket vector. - void UpdateErrorBucket( - float error, tensorflow::gtl::MutableArraySlice error_buckets) { - CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size()); - for (int i = 0; i < error_buckets.size(); ++i) { - if (error >= kErrorBucketBounds[i]) { - error_buckets[i]++; - } - } - } - - // Compares the two given elements from the expected and actual literals at - // the given literal_index and keeps track of various mismatch statistics. - void CompareValues(NativeT expected, NativeT actual, int64 linear_index) { - const bool is_nan_mismatch = - NanMismatch(expected, actual, error_.relaxed_nans); - float abs_error; - float rel_error; - if (actual == expected) { - abs_error = 0; - rel_error = 0; - } else if (is_nan_mismatch) { - num_nan_mismatches_++; - // A nan mismatch is considered to have infinite error. rel_error is used - // for sorting a std::set of the top mismatchs, and a nan value here will - // result in undefined behavior because nan's do not satisfy the strict - // weak ordering requirement of std containers. 
- abs_error = std::numeric_limits::infinity(); - rel_error = std::numeric_limits::infinity(); - } else { - abs_error = FpAbsoluteValue(actual - expected); - rel_error = abs_error / FpAbsoluteValue(expected); - } - const bool is_abs_mismatch = abs_error > error_.abs; - const bool is_rel_mismatch = rel_error > error_.rel; - const bool is_mismatch = - is_nan_mismatch || (is_abs_mismatch && is_rel_mismatch); - - // Update the error of the relative bucket only if the *absolute* error - // bound is exceeded and vice versa. - if (is_abs_mismatch) { - num_abs_mismatches_++; - UpdateErrorBucket(rel_error, &rel_error_buckets_); - } - if (is_rel_mismatch) { - num_rel_mismatches_++; - UpdateErrorBucket(abs_error, &abs_error_buckets_); - } - - UpdateAbsValueBucket(actual, is_mismatch); - - if (!is_mismatch) { - return; - } - - num_mismatches_++; - - // Keep track of the kTopRelativeErrorCount relative error mismatches. - if (top_rel_mismatches_.size() < kTopRelativeErrorCount || - rel_error > top_rel_mismatches_.begin()->rel_error) { - Mismatch mismatch = {actual, expected, rel_error, abs_error, - linear_index}; - top_rel_mismatches_.insert(mismatch); - if (top_rel_mismatches_.size() > kTopRelativeErrorCount) { - top_rel_mismatches_.erase(top_rel_mismatches_.begin()); - } - } - - mismatches_.data()[linear_index] = true; - } - - // Compares the two literals elementwise. - void CompareLiterals() { - // Fast path optimization for the case were layouts match. 
- if (LayoutUtil::Equal(actual_.shape().layout(), - expected_.shape().layout())) { - tensorflow::gtl::ArraySlice expected_data = - expected_.data(); - tensorflow::gtl::ArraySlice actual_data = - actual_.data(); - const int64 len = expected_data.size(); - for (int64 i = 0; i < len; ++i) { - CompareValues(expected_data[i], actual_data[i], i); - } - return; - } - std::vector multi_index(ShapeUtil::Rank(actual_.shape()), 0); - CompareLiteralsSlow(0, &multi_index); - } - - // Slow path for CompareLiterals when 'actual' and 'expected' literals have - // different layouts. In this case, multidimensional indices are constructed - // and indexed for each element. - void CompareLiteralsSlow(int64 dimension, std::vector* multi_index) { - if (dimension == multi_index->size()) { - CompareValues(expected_.Get(*multi_index), - actual_.Get(*multi_index), - IndexUtil::MultidimensionalIndexToLinearIndex( - actual_.shape(), *multi_index)); - } else { - for (int64 i = 0; i < expected_.shape().dimensions(dimension); ++i) { - (*multi_index)[dimension] = i; - CompareLiteralsSlow(dimension + 1, multi_index); - } - } - } - - // Writes the given literal to a file in the test temporary directory. - void WriteLiteralToTempFile(const Literal& literal, const string& name) { - int64 now_usec = tensorflow::Env::Default()->NowMicros(); - string filename = tensorflow::io::JoinPath( - tensorflow::testing::TmpDir(), - Printf("tempfile-%s-%llx-%s", Hostname().c_str(), now_usec, - name.c_str())); - TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), - filename, literal.ToProto())); - LOG(ERROR) << "wrote to " << name << " file: " << filename; - } - - // Returns an error message string with a detailed breakdown of the - // mismatches. Called after calling Run(). - string ErrorMessage() { - string out; - int64 element_count = ShapeUtil::ElementsIn(actual_.shape()); - - auto percent_string = [](float a, float b) { - float pct = b == 0.0 ? 
0.0 : 100.0 * a / b; - return Printf("%0.4f%%", pct); - }; - - Appendf(&out, - "\nMismatch count %lld (%s) in shape %s (%lld elements), abs bound " - "%g, rel bound %g\n", - num_mismatches_, - percent_string(num_mismatches_, element_count).c_str(), - ShapeUtil::HumanString(actual_.shape()).c_str(), - ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel); - if (num_nan_mismatches_ > 0) { - StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n"); - } - Appendf(&out, "Top relative error mismatches:\n"); - for (auto it = top_rel_mismatches_.rbegin(); - it != top_rel_mismatches_.rend(); ++it) { - StrAppend(&out, " ", it->ToString(actual_.shape()).c_str(), "\n"); - } - - if (!detailed_message_) { - return out; - } - - StrAppend(&out, "Absolute magnitude breakdown of actual values:\n"); - CHECK_EQ(abs_value_buckets_.size() + 1, kAbsValueBucketBounds.size()); - for (int i = 0; i < abs_value_buckets_.size(); ++i) { - const int64 bucket_size = abs_value_buckets_[i].first; - const int64 bucket_mismatches = abs_value_buckets_[i].second; - string mismatch_str = bucket_mismatches > 0 - ? 
Printf(", mismatches %lld", bucket_mismatches) - : ""; - Appendf(&out, " %-6g <= x < %-6g : %7lld (%9s)%s\n", - kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1], - bucket_size, percent_string(bucket_size, element_count).c_str(), - mismatch_str.c_str()); - } - - auto print_accum_buckets = [&](const string& header, int64 total, - tensorflow::gtl::ArraySlice buckets) { - StrAppend(&out, header, ":\n"); - Appendf(&out, " < %-6g : %7lld (%s)\n", kErrorBucketBounds[0], - total - buckets[0], - percent_string(total - buckets[0], total).c_str()); - CHECK_EQ(buckets.size(), kErrorBucketBounds.size()); - for (int i = 0; i < kErrorBucketBounds.size(); ++i) { - Appendf(&out, " >= %-6g : %7lld (%s)\n", kErrorBucketBounds[i], - buckets[i], percent_string(buckets[i], total).c_str()); - } - }; - Appendf(&out, "Elements exceeding abs error bound %g: %lld (%s)\n", - error_.abs, num_abs_mismatches_, - percent_string(num_abs_mismatches_, element_count).c_str()); - print_accum_buckets( - "Relative error breakdown of elements exceeding abs error bound", - num_abs_mismatches_, rel_error_buckets_); - Appendf(&out, "Elements exceeding rel error bound %g: %lld (%s)\n", - error_.rel, num_rel_mismatches_, - percent_string(num_rel_mismatches_, element_count).c_str()); - print_accum_buckets( - "Absolute error breakdown of elements exceeding rel error bound", - num_rel_mismatches_, abs_error_buckets_); - return out; - } - - // 'actual' and 'expected' literals being compared. - const Literal& expected_; - const Literal& actual_; - - // The error bounds of the comparison. - ErrorSpec error_; - - // Whether to include detailed breakdown of mismatches in the error message. - bool detailed_message_; - - // Number of element element mismatches encountered so far. - int64 num_mismatches_ = 0; - - // Number of elements with a nan mismatch. - int64 num_nan_mismatches_ = 0; - - // Number of elements which exceed the absolute/relative error bound. 
- int64 num_abs_mismatches_ = 0; - int64 num_rel_mismatches_ = 0; - - // A Literal containing which elements did not match in the expected and - // actual literals. mismatches_ contains PREDs and is of the same sizes as - // the comparison literals. - Literal mismatches_; - - // The number of mismatches to report in the output, sorted by relative error - // magnitude. - static constexpr int64 kTopRelativeErrorCount = 5; - - // The set of mismatches with the largest relative error. The size of this set - // is bounded by kTopRelativeErrorCount. - std::multiset top_rel_mismatches_; - - // Actual values are bucketed by absolute value. kAbsValueBucketBounds is the - // bounds of these buckets. abs_value_buckets_ contains a pair for each - // bucket: the element count and failure count. - static constexpr std::array kAbsValueBucketBounds = { - 0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits::infinity()}; - std::vector> abs_value_buckets_; - - // Buckets for relative and absolute errors. The relative error buckets only - // contains those elements which exceed the *absolute* error bound, and vice - // versa. This makes it easy to see the effect of adjusting the relative (or - // absolute) error bound on the success of the comparison. kErrorBucketBounds - // are the lower bounds of the buckets in both vectors. The error buckets are - // a cumulative distribution so an error value may appear in more than one - // bucket. For example an error value of 0.003 may appear in the buckets - // bounded by 0.01, 0.1, and 1.0. - static constexpr std::array kErrorBucketBounds = {0.0001, 0.001, - 0.01, 0.1, 1}; - std::vector abs_error_buckets_; - std::vector rel_error_buckets_; -}; - -template -constexpr std::array NearComparator::kAbsValueBucketBounds; -template -constexpr std::array NearComparator::kErrorBucketBounds; - -// Helper function for comparing two literals for nearness. Handles tuple-shapes -// via recursion. 
shape_index is the ShapeIndex of expected (or actual) -// currently being compared. -::testing::AssertionResult NearHelper(const Literal& expected, - const Literal& actual, - const ErrorSpec& error, - bool detailed_message, - const ShapeIndex& shape_index) { - ::testing::AssertionResult err = - LiteralTestUtil::EqualShapes(expected.shape(), actual.shape()); - if (!err) { - return err; - } - - if (ShapeUtil::IsTuple(expected.shape())) { - for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { - const auto expected_element = LiteralView::Create(expected, {i}); - const auto actual_element = LiteralView::Create(actual, {i}); - ShapeIndex element_index = shape_index; - element_index.push_back(i); - ::testing::AssertionResult res = - NearHelper(expected_element, actual_element, error, detailed_message, - element_index); - if (!res) { - string err_message = - Printf("\nArray at shape index %s%s", - element_index.ToString().c_str(), res.message()); - if (err) { - err = ::testing::AssertionFailure() << err_message; - } else { - err << err_message; - } - } - } - if (!err && shape_index.empty()) { - // Emit a top-level error message containing the top-level shape in case - // of mismatch. 
- int64 total_elements = RecursiveElementCount(actual.shape()); - err = ::testing::AssertionFailure() - << Printf("\nMismatches in shape %s (%lld elements):\n%s", - ShapeUtil::HumanString(actual.shape()).c_str(), - total_elements, err.message()); - } - return err; - } - - if (ShapeUtil::ElementIsFloating(expected.shape()) || - ShapeUtil::ElementIsComplex(expected.shape())) { - switch (expected.shape().element_type()) { - case BF16: - return NearComparator::Compare(expected, actual, error, - detailed_message); - break; - case F16: - return NearComparator::Compare(expected, actual, error, - detailed_message); - break; - case F32: - return NearComparator::Compare(expected, actual, error, - detailed_message); - break; - case F64: - return NearComparator::Compare(expected, actual, error, - detailed_message); - break; - case C64: - return NearComparator::Compare(expected, actual, error, - detailed_message); - break; - default: - LOG(FATAL) << "Unsupported primitive type in near comparator: " - << PrimitiveType_Name(expected.shape().element_type()) - << ". Must be floating-point type."; - } - } - - // Non-floating point literal. 
- return LiteralTestUtil::Equal(expected, actual); +/* static */ ::testing::AssertionResult LiteralTestUtil::Equal( + const LiteralSlice& expected, const LiteralSlice& actual) { + return StatusToAssertion(literal_comparison::Equal(expected, actual)); } -} // namespace - /* static */ ::testing::AssertionResult LiteralTestUtil::Near( - const Literal& expected, const Literal& actual, const ErrorSpec& error, - bool detailed_message) { - return NearHelper(expected, actual, error, detailed_message, - /*shape_index=*/{}); -} - -/* static */ void LiteralTestUtil::ExpectNear(const Literal& expected, - const Literal& actual, - const ErrorSpec& error, - const string& message) { - ::testing::AssertionResult res = - Near(expected, actual, error, /*detailed_message=*/false); - if (!res) { - res << "Expected: " << TruncateHugeLiteral(expected) << "\n"; - res << "Actual: " << TruncateHugeLiteral(actual) << "\n"; - if (!message.empty()) { - res << StrCat("\nmessage: ", message); - } - } - EXPECT_TRUE(res); + const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error_spec, bool detailed_message) { + return StatusToAssertion(literal_comparison::Near( + expected, actual, error_spec, detailed_message, &OnMiscompare)); } -/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( - const Literal& expected, const Literal& actual, +/* static */ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( + const LiteralSlice& expected, const LiteralSlice& actual, const tensorflow::gtl::optional& error) { if (error.has_value()) { VLOG(1) << "Expects near"; - return Near(expected, actual, *error); + return StatusToAssertion(literal_comparison::Near( + expected, actual, *error, /*detailed_message=*/false, &OnMiscompare)); } VLOG(1) << "Expects equal"; - return Equal(expected, actual); -} - -/*static*/ void LiteralTestUtil::ExpectNearOrEqual( - const Literal& expected, const Literal& actual, - const tensorflow::gtl::optional& error) { - 
EXPECT_TRUE(NearOrEqual(expected, actual, error)); -} - -/* static */ string LiteralTestUtil::MultiIndexAsString( - tensorflow::gtl::ArraySlice multi_index) { - return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}"); -} - -/* static */ std::unique_ptr LiteralTestUtil::Reshape( - tensorflow::gtl::ArraySlice new_dimensions, - tensorflow::gtl::ArraySlice minor_to_major, const Literal& literal) { - int64 new_num_elements = 1; - for (int64 i = 0; i < new_dimensions.size(); ++i) { - new_num_elements *= new_dimensions[i]; - } - CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); - CHECK_EQ(new_dimensions.size(), minor_to_major.size()); - - auto new_literal = MakeUnique( - ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); - - // Create a new shape with the given minor-to-major layout. This shape is used - // solely for converting linear address to multi-dimensional addresses when - // writing elements to the new literal. - Shape shape_with_layout = new_literal->shape(); - *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); - - // Copy data into new literal, element-by-element. 
- for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { - std::vector from_multi_index = - IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i); - std::vector to_multi_index = - IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); - switch (literal.shape().element_type()) { - case PRED: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case U8: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case U32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case S32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case U64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case S64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case F32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case F64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case C64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - default: - LOG(FATAL) << "Unhandled primitive element type: " - << PrimitiveType_Name(literal.shape().element_type()); - } - } - - return new_literal; + return StatusToAssertion(literal_comparison::Equal(expected, actual)); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index a755568c0f098e15512bd1d3720269c867bc9c49..d1b8a6cf0b2552f1b7d95a2560d502da14ddc39a 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/error_spec.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -38,282 +39,190 @@ limitations under the License. namespace xla { -// Structure describing permissible absolute and relative error bounds. -struct ErrorSpec { - explicit ErrorSpec(float aabs, float arel = 0, bool relaxed_nans = false) - : abs(aabs), rel(arel), relaxed_nans(relaxed_nans) {} - - float abs; // Absolute error bound. - float rel; // Relative error bound. - - // If relaxed_nans is true then any result is valid if we are expecting NaNs. - // In effect, this allows the tested operation to produce incorrect results - // for inputs outside its mathematical domain. - bool relaxed_nans; -}; - // Utility class for making expectations/assertions related to XLA literals. class LiteralTestUtil { public: // Asserts that the given shapes have the same rank, dimension sizes, and // primitive types. - static ::testing::AssertionResult EqualShapes(const Shape& expected, - const Shape& actual); - static void AssertEqualShapes(const Shape& expected, const Shape& actual); + static ::testing::AssertionResult EqualShapes( + const Shape& expected, const Shape& actual) TF_MUST_USE_RESULT; // Asserts that the provided shapes are equal as defined in AssertEqualShapes // and that they have the same layout. - static void AssertEqualShapesAndLayouts(const Shape& expected, - const Shape& actual); - - // If the given literal's data type is bfloat16, converts it to a float - // literal; otherwise, returns a copy of it. If the literal is a tuple, - // recursively converts its elements. 
- static std::unique_ptr ConvertBF16ToF32(const Literal& bf16_literal); - - // If the given literal's data type is float, converts it to a bfloat16 - // literal; otherwise, returns a copy of it. If the literal is a tuple, - // recursively converts its elements. - static std::unique_ptr ConvertF32ToBF16(const Literal& f32_literal); - - // Asserts that the expected and actual literals are (bitwise) equal for all - // elements in the literal. Also, asserts that the rank, dimensions sizes, and - // primitive type are equal. - static ::testing::AssertionResult Equal( - const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT; + static ::testing::AssertionResult EqualShapesAndLayouts( + const Shape& expected, const Shape& actual) TF_MUST_USE_RESULT; - // Expects that expected and actual are Equal. - static void ExpectEqual(const Literal& expected, const Literal& actual, - const string& message = ""); - - // Expects that expected and actual are Not Equal. - static void ExpectNotEqual(const Literal& expected, const Literal& actual); + static ::testing::AssertionResult Equal(const LiteralSlice& expected, + const LiteralSlice& actual) + TF_MUST_USE_RESULT; // Asserts the given literal are (bitwise) equal to given expected values. template - static void ExpectR0Equal(NativeT expected, const Literal& actual); + static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual); + template static void ExpectR1Equal(tensorflow::gtl::ArraySlice expected, - const Literal& actual); + const LiteralSlice& actual); template static void ExpectR2Equal( std::initializer_list> expected, - const Literal& actual); + const LiteralSlice& actual); + template static void ExpectR3Equal( std::initializer_list< std::initializer_list>> expected, - const Literal& actual); + const LiteralSlice& actual); // Asserts the given literal are (bitwise) equal to given array. 
template static void ExpectR2EqualArray2D(const Array2D& expected, - const Literal& actual); + const LiteralSlice& actual); template static void ExpectR3EqualArray3D(const Array3D& expected, - const Literal& actual); + const LiteralSlice& actual); template static void ExpectR4EqualArray4D(const Array4D& expected, - const Literal& actual); + const LiteralSlice& actual); - // Asserts that the expected and actual literals are within the given error - // bound for all elements. Also, asserts that the rank, dimensions sizes, and - // bounds are equivalent. + // Decorates literal_comparison::Near() with an AssertionResult return type. // - // Tuples are matched recursively. When comparing tensors of - // non-floating-point type, checks for exact equality, ignoring the ErrorSpec. - // - // If the shape of the literals is neither a complex/floating-point tensor nor - // a tuple which contains a complex/floating-point tensor, Near() is - // equivalent to Equal(). We don't raise an error in this case, because we - // want to allow callers to call Near() even if they have no preconceptions - // about the shapes being compared. - // - // If detailed_message is true, then the error message in the assertion result - // will contain a more detailed breakdown of mismatches. + // See comment on literal_comparison::Near(). static ::testing::AssertionResult Near( - const Literal& expected, const Literal& actual, const ErrorSpec& error, + const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error_spec, bool detailed_message = false) TF_MUST_USE_RESULT; - // Expects expected and actual to be Near with the given error. - static void ExpectNear(const Literal& expected, const Literal& actual, - const ErrorSpec& error, const string& message = ""); - // Asserts the given literal are within the given error bound of the given // expected values. Only supported for floating point values. 
template - static void ExpectR0Near(NativeT expected, const Literal& actual, + static void ExpectR0Near(NativeT expected, const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR1Near(tensorflow::gtl::ArraySlice expected, - const Literal& actual, const ErrorSpec& error); + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR2Near( std::initializer_list> expected, - const Literal& actual, const ErrorSpec& error); + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR3Near( std::initializer_list< std::initializer_list>> expected, - const Literal& actual, const ErrorSpec& error); + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR4Near( std::initializer_list>>> expected, - const Literal& actual, const ErrorSpec& error); + const LiteralSlice& actual, const ErrorSpec& error); // Asserts the given literal are within the given error bound to the given // array. Only supported for floating point values. template static void ExpectR2NearArray2D(const Array2D& expected, - const Literal& actual, + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR3NearArray3D(const Array3D& expected, - const Literal& actual, + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR4NearArray4D(const Array4D& expected, - const Literal& actual, + const LiteralSlice& actual, const ErrorSpec& error); // If the error spec is given, returns whether the expected and the actual are // within the error bound; otherwise, returns whether they are equal. Tuples // will be compared recursively. 
static ::testing::AssertionResult NearOrEqual( - const Literal& expected, const Literal& actual, + const LiteralSlice& expected, const LiteralSlice& actual, const tensorflow::gtl::optional& error) TF_MUST_USE_RESULT; - // If the error spec is given, expects the expected and the actual to be near; - // otherwise, expects them to be equal. Tuples will be compared recursively. - static void ExpectNearOrEqual( - const Literal& expected, const Literal& actual, - const tensorflow::gtl::optional& error); - - // Returns a multi-dimensional index as a string. For example: '{7, 8}' will - // be returned for a 2-dimensional index with dimension 0 index equal to 7, - // dimension 1 equal to 8. - static string MultiIndexAsString( - tensorflow::gtl::ArraySlice multi_index); - - // Creates a literal with a new shape with the given new dimensions using the - // data in the given input literal. For reshaping purposes the (flat) data - // buffer of the input literal is assumed to have the given minor_to_major - // layout order. - static std::unique_ptr Reshape( - tensorflow::gtl::ArraySlice new_dimensions, - tensorflow::gtl::ArraySlice minor_to_major, - const Literal& literal); - - // Creates a literal with the supplied shape, and uses the provided value - // generator to populate the literal's values. - // Returns the new literal object, or an error Status if failed. - template < - PrimitiveType type, - typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, - const std::function)>& generator); - - // Creates a literal with the supplied shape, and initializes the literal - // values using a normal distribution with given mean and stddev standard - // deviation, and using the engine as entropy generator. - // Returns the new literal object, or an error Status if failed. 
- template < - PrimitiveType type, typename E, - typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, E* engine, T mean, T stddev); - - // Creates a literal with the supplied shape, and initializes the literal - // values using a normal distribution with given mean and stddev standard - // deviation. - // Returns the new literal object, or an error Status if failed. - template < - PrimitiveType type, - typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, T mean, T stddev); - private: TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil); }; template /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, - const Literal& actual) { - ExpectEqual(*Literal::CreateR0(expected), actual); + const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR0(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR1Equal( - tensorflow::gtl::ArraySlice expected, const Literal& actual) { - ExpectEqual(*Literal::CreateR1(expected), actual); + tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR1(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2Equal( std::initializer_list> expected, - const Literal& actual) { - ExpectEqual(*Literal::CreateR2(expected), actual); + const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR2(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR3Equal( std::initializer_list>> expected, - const Literal& actual) { - ExpectEqual(*Literal::CreateR3(expected), actual); + const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR3(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2EqualArray2D( - const Array2D& expected, const Literal& actual) { - ExpectEqual(*Literal::CreateR2FromArray2D(expected), actual); + const Array2D& expected, const 
LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR2FromArray2D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR3EqualArray3D( - const Array3D& expected, const Literal& actual) { - ExpectEqual(*Literal::CreateR3FromArray3D(expected), actual); + const Array3D& expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR3FromArray3D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR4EqualArray4D( - const Array4D& expected, const Literal& actual) { - ExpectEqual(*Literal::CreateR4FromArray4D(expected), actual); + const Array4D& expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR4FromArray4D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, - const Literal& actual, + const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR0(expected), actual, error); + EXPECT_TRUE(Near(*Literal::CreateR0(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR1Near( - tensorflow::gtl::ArraySlice expected, const Literal& actual, + tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR1(expected), actual, error); + EXPECT_TRUE(Near(*Literal::CreateR1(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2Near( std::initializer_list> expected, - const Literal& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR2(expected), actual, error); + const LiteralSlice& actual, const ErrorSpec& error) { + EXPECT_TRUE(Near(*Literal::CreateR2(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR3Near( std::initializer_list>> expected, - const Literal& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR3(expected), actual, error); + const LiteralSlice& actual, const ErrorSpec& error) { + EXPECT_TRUE(Near(*Literal::CreateR3(expected), actual, error)); } 
template @@ -321,63 +230,29 @@ template std::initializer_list>>> expected, - const Literal& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR4(expected), actual, error); + const LiteralSlice& actual, const ErrorSpec& error) { + EXPECT_TRUE(Near(*Literal::CreateR4(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2NearArray2D( - const Array2D& expected, const Literal& actual, + const Array2D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR2FromArray2D(expected), actual, error); + EXPECT_TRUE(Near(*Literal::CreateR2FromArray2D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR3NearArray3D( - const Array3D& expected, const Literal& actual, + const Array3D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR3FromArray3D(expected), actual, error); + EXPECT_TRUE(Near(*Literal::CreateR3FromArray3D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR4NearArray4D( - const Array4D& expected, const Literal& actual, + const Array4D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR4FromArray4D(expected), actual, error); -} - -template -/* static */ StatusOr> -LiteralTestUtil::CreateRandomLiteral( - const Shape& shape, - const std::function)>& generator) { - using NativeT = typename primitive_util::PrimitiveTypeToNative::type; - TF_RET_CHECK(shape.element_type() == type); - std::unique_ptr literal = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(literal.get()->Populate( - [&](tensorflow::gtl::ArraySlice indexes) { - return generator(indexes); - })); - return std::move(literal); -} - -template -/* static */ StatusOr> -LiteralTestUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, - T stddev) { - using NativeT = typename primitive_util::PrimitiveTypeToNative::type; - std::normal_distribution generator(mean, stddev); - 
return CreateRandomLiteral( - shape, [&](tensorflow::gtl::ArraySlice /*indexes*/) { - return generator(*engine); - }); -} - -template -/* static */ StatusOr> -LiteralTestUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) { - std::minstd_rand0 engine; - return CreateRandomLiteral(shape, &engine, mean, stddev); + EXPECT_TRUE(Near(*Literal::CreateR4FromArray4D(expected), actual, error)); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index 9d619a77c7e8d6398b559e8f562cd7f8194e0811..bbac7285aefbb1f028fad152e4b7fe6af01e9f6d 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -34,7 +34,7 @@ TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) { std::unique_ptr literal = Literal::MakeTuple({ Literal::CreateR0(42).get(), Literal::CreateR0(64).get(), }); - LiteralTestUtil::ExpectEqual(*literal, *literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal)); } TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { @@ -97,6 +97,15 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { } } +TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { + auto expected = Literal::CreateR1({1, 2, 3}); + auto actual = Literal::CreateR1({4, 5, 6}); + ::testing::AssertionResult result = + LiteralTestUtil::Equal(*expected, *actual); + EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}")); + EXPECT_THAT(result.message(), ::testing::HasSubstr("actual: {4, 5, 6}")); +} + TEST(LiteralTestUtilTest, NearComparatorR1) { auto a = Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index 2f46ee0be216d7dabf1c476d3cfb7d528f8ab6a4..082bc34136e004795ce300c66591758f47c665fe 100644 --- 
a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -124,8 +124,7 @@ class LLVMCompilerTest : public ::testing::Test { static std::unique_ptr CreateNewModule() { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - return MakeUnique(TestName(), VersionedComputationHandle(), - config); + return MakeUnique(TestName(), config); } }; diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 44c6811df84f49b6c1b24c11449939e2d375a9d1..96858c00d6bbe59b673a34e7d5ca261756709596 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -210,12 +210,12 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR2Equal( {{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralView::Create(*result_literal, {1})); + LiteralSlice(*result_literal, {1})); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {2})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {2})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { @@ -239,16 +239,16 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1})); LiteralTestUtil::ExpectR2Equal( {{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralView::Create(*result_literal, {0, 0})); + LiteralSlice(*result_literal, {0, 0})); LiteralTestUtil::ExpectR2Equal( 
{{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralView::Create(*result_literal, {0, 1})); + LiteralSlice(*result_literal, {0, 1})); LiteralTestUtil::ExpectR2Equal( {{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralView::Create(*result_literal, {0, 2})); + LiteralSlice(*result_literal, {0, 2})); } XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { @@ -274,9 +274,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { @@ -321,9 +321,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( {{56.0f, 46.0f}, {36.0f, 26.0f}}, - LiteralView::Create(*result_literal, {0})); + LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR1Equal( - {40.0f, 71.0f, 117.0f}, LiteralView::Create(*result_literal, {1})); + {40.0f, 71.0f, 117.0f}, LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { @@ -361,9 +361,9 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{-1.0, -2.0}, {-3.0, -4}}, LiteralView::Create(*result_literal, {0})); + {{-1.0, -2.0}, {-3.0, -4}}, LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR1Equal( - {264.0, 73.0, 133.0}, LiteralView::Create(*result_literal, {1})); + {264.0, 73.0, 133.0}, LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { @@ -391,16 +391,16 @@ 
XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { std::unique_ptr result_0_literal = ShapedBufferToLiteral(result_0); LiteralTestUtil::ExpectR2Equal( {{-1.0, -2.0}, {-3.0, -4.0}}, - LiteralView::Create(*result_0_literal, {0})); + LiteralSlice(*result_0_literal, {0})); LiteralTestUtil::ExpectR2Equal( - {{22.0, 6.0}, {8.0, 10}}, LiteralView::Create(*result_0_literal, {1})); + {{22.0, 6.0}, {8.0, 10}}, LiteralSlice(*result_0_literal, {1})); ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0}); std::unique_ptr result_1_literal = ShapedBufferToLiteral(result_1); LiteralTestUtil::ExpectR2Equal( - {{1.0, 2.0}, {3.0, 4.0}}, LiteralView::Create(*result_1_literal, {0})); + {{1.0, 2.0}, {3.0, 4.0}}, LiteralSlice(*result_1_literal, {0})); LiteralTestUtil::ExpectR2Equal( - {{44.0, 12.0}, {16.0, 20}}, LiteralView::Create(*result_1_literal, {1})); + {{44.0, 12.0}, {16.0, 20}}, LiteralSlice(*result_1_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { @@ -447,7 +447,7 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { for (int i = 0; i < kElementCount; ++i) { LiteralTestUtil::ExpectR1Near( - {2.0f * i, 0.0f}, LiteralView::Create(*result_literal, {i}), + {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_); } } @@ -502,7 +502,7 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { for (int i = 0; i < kFanout; ++i) { for (int j = 0; j < kFanout; ++j) { LiteralTestUtil::ExpectR0Near( - i + j + i * kFanout + j, LiteralView::Create(*result_literal, {i, j}), + i + j + i * kFanout + j, LiteralSlice(*result_literal, {i, j}), error_spec_); } } @@ -548,7 +548,7 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { index.push_back(0); } LiteralTestUtil::ExpectR0Equal( - 165.0, LiteralView::Create(*result_literal, index)); + 165.0, LiteralSlice(*result_literal, index)); } XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { @@ -754,9 +754,9 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { 
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); std::unique_ptr tuple_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR1Equal( - {2.0f, 4.0f, 6.0f}, LiteralView::Create(*tuple_literal, {0})); + {2.0f, 4.0f, 6.0f}, LiteralSlice(*tuple_literal, {0})); LiteralTestUtil::ExpectR1Equal( - {1.0f, 2.0f, 3.0f}, LiteralView::Create(*tuple_literal, {1})); + {1.0f, 2.0f, 3.0f}, LiteralSlice(*tuple_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index e859b3059eea86b362443c3269f99ccae941dfe2..88797a7d0a7d0567b3a380c5fb1ad0c0ee875587 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -35,9 +35,9 @@ namespace xla { /* static */ TestAllocator* LocalClientTestBase::allocator_; -StatusOr TestAllocator::Allocate(int device_ordinal, - uint64 size, - bool retry_on_failure) { +StatusOr TestAllocator::Allocate(int device_ordinal, + uint64 size, + bool retry_on_failure) { VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")"; { tensorflow::mutex_lock lock(count_mutex_); @@ -48,8 +48,7 @@ StatusOr TestAllocator::Allocate(int device_ordinal, retry_on_failure); } -tensorflow::Status TestAllocator::Deallocate(int device_ordinal, - se::DeviceMemoryBase* mem) { +Status TestAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) { VLOG(2) << "Deallocate(" << device_ordinal << ")"; { tensorflow::mutex_lock lock(count_mutex_); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index 3bbb760c806412a671bc2502846e123e2582fd16..258226523d830b40ecaa761df95988dc90f5ca47 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -46,10 +46,9 @@ class TestAllocator : public 
StreamExecutorMemoryAllocator { platform, PlatformUtil::GetStreamExecutors(platform).ValueOrDie()) { } - StatusOr Allocate(int device_ordinal, uint64 size, - bool retry_on_failure) override; - tensorflow::Status Deallocate(int device_ordinal, - se::DeviceMemoryBase* mem) override; + StatusOr Allocate(int device_ordinal, uint64 size, + bool retry_on_failure) override; + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; // Return the number of allocations that have been performed. int64 allocation_count() const; diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 7df45bebebdd3eb2e71f27d831a8e2ac9e3b5f7c..3975e9125703ee081d4e84fa8bd27fcbe483ac34 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -488,10 +488,9 @@ TEST_F(MapTest, MapOperantionWithBuildError) { StatusOr computation_status = builder.Build(); ASSERT_TRUE(!computation_status.ok()); - EXPECT_THAT( - computation_status.status().ToString(), - ::testing::HasSubstr("error from: ErrorAdd: Binary op BINOP_ADD with " - "different element types: f32[] and u16[]")); + EXPECT_THAT(computation_status.status().ToString(), + ::testing::HasSubstr("error from: ErrorAdd: Binary op add with " + "different element types: f32[] and u16[]")); } // MapTest disables inline and algsimp. MapTestWithFullOpt runs all diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 464cc012140d4838de88c5bf5b3b2f1372c2c19b..27fd36e06acdc589f3a84ad561164e4a33b93506 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 0a603f4954badd12adf3144320789a5edd0d9c6c..41f723edf1ff3518686231f31b61b64291b1f6bf 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -108,7 +107,7 @@ class MultiOutputFusionTest : public HloTestBase { expect.PopulateWithValue(size * 1.5f * 3.5f); auto actual = ExecuteAndTransfer( std::move(hlo_module), {Literal::CreateR0(-9.0f).get(), &arg1}); - LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); } void RunTest1D(bool manual_fusion, int size) { @@ -168,7 +167,7 @@ class MultiOutputFusionTest : public HloTestBase { Literal expect = std::move(*Literal::CreateR1({size * 1.5f * 3.5f})); auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1}); - LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); } }; @@ -211,5 +210,309 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { *result, *Literal::MakeTupleOwned(Literal::CreateR0(42)))); } +XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { + const char* testcase = R"( + HloModule m + + fused_computation { + p = f32[4] parameter(0) + multiply = 
f32[4] multiply(p, p) + less-than = pred[4] less-than(p, multiply) + ROOT tuple = (pred[4], f32[4]) tuple(less-than, multiply) + } + + ENTRY PredFloatMOF { + p0 = f32[4] parameter(0) + fusion = (pred[4], f32[4]) fusion(p0), kind=kLoop, calls=fused_computation + gte0 = pred[4] get-tuple-element(fusion), index=0 + gte1 = f32[4] get-tuple-element(fusion), index=1 + const = f32[4] constant({0, 0, 0, 0}) + ROOT select = f32[4] select(gte0, gte1, const) + })"; + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR1({1.0, 2.0, 3.0, -1.0}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::CreateR1({0.0, 4.0, 9.0, 1.0}))); +} + +XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { + const char* testcase = R"( + HloModule m + + fused_computation { + p = f32[] parameter(0) + multiply = f32[] multiply(p, p) + less-than = pred[] less-than(p, multiply) + ROOT tuple = (pred[], f32[]) tuple(less-than, multiply) + } + + map_computation { + p0 = f32[] parameter(0) + fusion = (pred[], f32[]) fusion(p0), kind=kLoop, calls=fused_computation + gte0 = pred[] get-tuple-element(fusion), index=0 + gte1 = f32[] get-tuple-element(fusion), index=1 + const = f32[] constant(0) + ROOT select = f32[] select(gte0, gte1, const) + } + + ENTRY MapMOF { + p1 = f32[3] parameter(0) + ROOT map = f32[3] map(p1), to_apply=map_computation + })"; + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR1({1.0, 2.0, 3.0}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::CreateR1({0.0, 4.0, 9.0}))); +} + +const char* const kScalarOps = R"( + HloModule m + + Add { + lhsadd = f32[] parameter(0) + rhsadd = f32[] parameter(1) + ROOT add = f32[] add(lhsadd, 
rhsadd) + } + + Max { + lhsmax = f32[] parameter(0) + rhsmax = f32[] parameter(1) + ROOT max = f32[] maximum(lhsmax, rhsmax) + } +)"; + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionMinor)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, + calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, + *Literal::MakeTupleOwned(Literal::CreateR2({{3, 7}, {11, 15}}), + Literal::CreateR2({{5, 16}, {36, 64}})))); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionMajor)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, + calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = 
Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::MakeTupleOwned( + Literal::CreateR2({{6, 8}, {10, 12}}), + Literal::CreateR2({{25, 36}, {49, 64}})))); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionScalar)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(1.17549e-38) + r2 = f32[2]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Max + r3 = f32[2]{0} reduce(mul, c0), dimensions={0,2}, to_apply=Add + ROOT tuple = (f32[2]{0}, f32[2]{0}, f32[2]{0}) tuple(r1, r2, r3) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput, + calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::MakeTupleOwned(Literal::CreateR1({14, 22}), + Literal::CreateR1({36, 64}), + Literal::CreateR1({66, 138})))); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionMinorWithExtraOutput)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, 
f32[2,2]{1,0}) + tuple(p0, r1, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), + kind=kInput, calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, + *Literal::MakeTupleOwned( + Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}), + Literal::CreateR2({{3, 7}, {11, 15}}), + Literal::CreateR2({{5, 16}, {36, 64}})))); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionMajorWithExtraOutput)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) + tuple(r1, mul, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) fusion(p), + kind=kInput, calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, + *Literal::MakeTupleOwned( + Literal::CreateR2({{6, 8}, {10, 12}}), + Literal::CreateR3({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), + Literal::CreateR2({{25, 36}, {49, 64}})))); +} + +XLA_TEST_F(MultiOutputFusionTest, + 
DISABLED_ON_CPU(MultiOutputReduceFusionScalarWithExtraOutput)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + mul2 = f32[2,2,2]{2,1,0} multiply(p0, c1) + ROOT tuple = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) + tuple(r1, mul, mul2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) fusion(p), + kind=kInput, calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, + *Literal::MakeTupleOwned( + Literal::CreateR1({14, 22}), + Literal::CreateR3({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), + Literal::CreateR3( + {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})))); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionNonConstInit)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + init1 = f32[] parameter(1) + init2 = f32[] parameter(2) + r1 = f32[2,2]{1,0} reduce(p0, init1), dimensions={2}, to_apply=Add + r2 = f32[2,2]{1,0} reduce(p0, init2), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + i = f32[] parameter(1) + j = f32[] parameter(2) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j), kind=kInput, + calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = 
Literal::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto init1 = Literal::CreateR0(5); + auto init2 = Literal::CreateR0(6); + TF_ASSERT_OK_AND_ASSIGN( + auto result, + Execute(std::move(module), {param.get(), init1.get(), init2.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::MakeTupleOwned( + Literal::CreateR2({{167, 172}, {176, 180}}), + Literal::CreateR2({{6, 6}, {6, 8}})))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index 97dab860c06bddb2a0ffd45e48c4912c5f55d574..838f1b4e2f0f0e0871ec717bdeefcbbc653397e3 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" @@ -161,7 +160,7 @@ XLA_TEST_F(ParamsTest, MissingParameter) { auto p = builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "param2"); auto computation_status = builder.Build(); - ASSERT_NE(computation_status.status(), tensorflow::Status::OK()); + ASSERT_NE(computation_status.status(), Status::OK()); } XLA_TEST_F(ParamsTest, UnusedParameter) { diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 29a4f75001c688f2215745ab913df68bf2f62b76..1a2de6937c3e134852a730f62f7b56417cf49b28 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -273,11 +273,11 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { &execution_options_)); } - LiteralTestUtil::ExpectEqual(*result1, *result2); - LiteralTestUtil::ExpectEqual(*result1, *result3); - LiteralTestUtil::ExpectNotEqual(*result1, *result4); - 
LiteralTestUtil::ExpectNotEqual(*result4, *result5); - LiteralTestUtil::ExpectNotEqual(*result5, *result6); + EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result2)); + EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result3)); + EXPECT_FALSE(LiteralTestUtil::Equal(*result1, *result4)); + EXPECT_FALSE(LiteralTestUtil::Equal(*result4, *result5)); + EXPECT_FALSE(LiteralTestUtil::Equal(*result5, *result6)); } XLA_TEST_F(PrngTest, TenValuesN01) { diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc index c0a2c0ca4cb8414e0771a541b9f963f9aedc8376..9052b188ed09a715b6ad7c3a40dc853d02cdd70c 100644 --- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc @@ -15,9 +15,9 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" @@ -73,7 +73,7 @@ ENTRY reduce.1 { } )"; - return tools::Parse(hlo_string); + return ParseHloString(hlo_string); } // TODO(b/72454718): XLA:GPU does not support executing code compiled without diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index bcc05c2d41d8439b021cdf6533b5ca87c19aec1f..d671d40456a276a44b462f390c95aa4af301263a 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -34,7 +34,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 10a3da3a387641ec45baf02d15790e32371601fa..266760e8202fddc48792ac66dda334255e428808 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -356,12 +356,8 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); - std::unique_ptr arg_literal = Literal::CreateFromShape(shape); - auto generator = [&](tensorflow::gtl::ArraySlice indexes) -> float { - return 1.0f; - }; - TF_EXPECT_OK(arg_literal->Populate(generator)); - + auto arg_literal = MakeUnique(shape); + arg_literal->PopulateWithValue(1.0f); const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); Padding padding = Padding::kValid; @@ -371,13 +367,8 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector output_dims = {6, 8, 6, 6, 8, 8}; Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout); - std::unique_ptr expected = Literal::CreateFromShape(result_shape); - auto out_generator = - [&](tensorflow::gtl::ArraySlice indexes) -> float { - return 27.0f; - }; - TF_EXPECT_OK(expected->Populate(out_generator)); - + auto expected = MakeUnique(result_shape); + expected->PopulateWithValue(27.0f); ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } @@ -1348,7 +1339,7 @@ INSTANTIATE_TEST_CASE_P( class ReduceWindowTextTest : public HloTestBase {}; TEST_F(ReduceWindowTextTest, R2General256x384) { - const string& hlo_string = R"( + const string hlo_string = R"( 
HloModule R2Window mul { lhs = f32[] parameter(0) @@ -1365,7 +1356,7 @@ ENTRY R2Window { } TEST_F(ReduceWindowTextTest, R2General256x384Layout01) { - const string& hlo_string = R"( + const string hlo_string = R"( HloModule R2Window mul { lhs = f32[] parameter(0) @@ -1382,7 +1373,7 @@ ROOT reduce-window = f32[256,384]{0,1} reduce-window(operand, constant), window= } TEST_F(ReduceWindowTextTest, R2General2x5) { - const string& hlo_string = R"( + const string hlo_string = R"( HloModule R2Window mul { lhs = f32[] parameter(0) @@ -1399,7 +1390,7 @@ ENTRY R2Window { } TEST_F(ReduceWindowTextTest, R2EffectiveScalar) { - const string& hlo_string = R"( + const string hlo_string = R"( HloModule R2Window mul { lhs = f32[] parameter(0) @@ -1417,7 +1408,7 @@ ENTRY R2Window { } TEST_F(ReduceWindowTextTest, R3EffectiveScalar) { - const string& hlo_string = R"( + const string hlo_string = R"( HloModule R3Window mul { lhs = f32[] parameter(0) @@ -1435,7 +1426,7 @@ ENTRY R3Window { } TEST_F(HloTestBase, ReduceWindowIdentity) { - const string& hlo_string = R"( + const string hlo_string = R"( HloModule ReduceWindowIdentity identity.pad_to_reduce_window { param0 = f32[] parameter(0) @@ -1444,7 +1435,26 @@ identity.pad_to_reduce_window { ENTRY reduce-window-identity { operand = f32[1,32,64]{2,1,0} parameter(0) constant.4466 = f32[] constant(0) - ROOT reduce-window = f32[1,33,64]{2,1,0} reduce-window(operand, constant.4466), window={size=1x1x1 pad=0_0x1_0x0_0}, to_apply=identity.pad_to_reduce_window + ROOT reduce-window = f32[1,33,64]{2,1,0} reduce-window(operand, constant.4466), window={size=1x1x1 pad=0_0x1_0x0_0}, to_apply=identity.pad_to_reduce_window +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); +} + +TEST_F(HloTestBase, ReduceWindowS32) { + const string hlo_string = R"( +HloModule reduce-window + +%identity.pad_to_reduce_window (param0: s32[], param1: s32[]) -> s32[] { + %param0 = s32[] parameter(0) + ROOT %param1 = s32[] parameter(1) +} + +ENTRY 
%reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] { + %parameter.0 = s32[81,8]{1,0} parameter(0) + %parameter.1 = s32[] parameter(1) + ROOT %reduce-window = s32[82,8]{1,0} reduce-window(s32[81,8]{1,0} %parameter.0, s32[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window } )"; diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc index 5ebd5268992846e80dcce2675f8e92038e190ecf..da1b588ec41cef711412367e89b2a9b1029bca71 100644 --- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -33,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index d7462d581b8596dc43b81b0162b3f5020cebb546..a4580cd71d46ad0a0186eddd51291f9c322b6f49 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -656,9 +656,9 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { std::unique_ptr expected = Literal::CreateR2FromArray2D(expected_array); if (use_bfloat16()) { - expected = LiteralTestUtil::ConvertF32ToBF16(*expected); + expected = Literal::ConvertF32ToBF16(*expected); } - LiteralTestUtil::ExpectEqual(*expected, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); } XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { @@ -731,7 +731,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); 
std::unique_ptr expected = - LiteralTestUtil::Reshape({2, 1}, {1, 0}, *input_literal); + Literal::ReshapeSlice({2, 1}, {1, 0}, *input_literal); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, zero_error_spec_); } @@ -753,7 +753,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); std::unique_ptr expected = - LiteralTestUtil::Reshape({4, 2}, {1, 0}, *input_literal); + Literal::ReshapeSlice({4, 2}, {1, 0}, *input_literal); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, zero_error_spec_); } @@ -817,7 +817,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { // Since the reshape is a no-op, verify that it does not change the underlying // data. if (use_bfloat16()) { - auto expected = LiteralTestUtil::ConvertF32ToBF16(*input_literal); + auto expected = Literal::ConvertF32ToBF16(*input_literal); EXPECT_EQ(expected->data(), output_literal->data()); } else { EXPECT_EQ(input_literal->data(), output_literal->data()); @@ -886,7 +886,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -915,7 +915,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -944,7 +944,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - 
LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -974,7 +974,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -1003,7 +1003,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal) + Literal::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal) ->Relayout(input_literal->shape().layout()); // Specify the requested output shape explicitly to ensure that this reshape diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc index 8cbfcc6f5c4272706a0f9fd809041516bf32432b..7cfca781acda15879075f4386c2096e537877aac 100644 --- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -100,7 +100,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { EXPECT_EQ(46.0f, actual->Get({1, 1})); std::unique_ptr round_tripped = RoundTripToServer(*actual); - LiteralTestUtil::ExpectEqual(*round_tripped, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual)); } TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { @@ -135,7 +135,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { EXPECT_EQ(46.0f, actual->Get({1, 1})); std::unique_ptr round_tripped = 
RoundTripToServer(*actual); - LiteralTestUtil::ExpectEqual(*round_tripped, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc index 32db45f8a66266712ba4091c2aa6368f0b822bd2..f334a8c1318a59bbfdd27dd1a63ed162600089ce 100644 --- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc @@ -41,7 +41,7 @@ class RoundTripTransferTest : public ClientLibraryTestBase { client_->TransferToServer(original).ConsumeValueOrDie(); std::unique_ptr result = client_->Transfer(*data).ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(original, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(original, *result)); } }; diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index f35bc43a4952137b4b6c94c771819e0514d4228f..308d3fc78a51e63c0e3db8c0cda18caf11f665bd 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -390,7 +390,7 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { &execution_options_) .ConsumeValueOrDie(); auto expected_literal = Literal::CreateR0(dividend / divisor); - LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } } } @@ -431,7 +431,7 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { &execution_options_) .ConsumeValueOrDie(); auto expected_literal = Literal::CreateR0(dividend % divisor); - LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } } } diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc index 
3d694a9c3fe894107c3b0a8fc2e5d07310cb476c..72707f224446c7585d1d90ac6681a7b38c41d5f1 100644 --- a/tensorflow/compiler/xla/tests/select_test.cc +++ b/tensorflow/compiler/xla/tests/select_test.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 52195db2aa74710b901dd7744a670764a034e96b..5653bf11a7364bf9ed79bcb6b53f7db31f454803 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -197,9 +197,10 @@ class SliceR1Test : public ClientLibraryTestBase, // vector. tensorflow::gtl::InlinedVector input(spec.input_dim0); std::iota(input.begin(), input.end(), NativeT()); + auto literal = Literal::CreateR1(input); XlaBuilder builder(TestName()); - auto original = builder.ConstantR1(input); + auto original = builder.Parameter(0, literal->shape(), "p0"); builder.Slice(original, {spec.slice_start}, {spec.slice_limit}, {spec.slice_stride}); @@ -210,7 +211,9 @@ class SliceR1Test : public ClientLibraryTestBase, expected.push_back(i); } - ComputeAndCompareR1(&builder, expected, {}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, + client_->TransferToServer(*literal)); + ComputeAndCompareR1(&builder, expected, {arg.get()}); } }; @@ -365,15 +368,18 @@ XLA_TEST_P(SliceR2Test, DoIt) { const R2Spec& spec = GetParam(); Array2D input(spec.input_dim0, spec.input_dim1); input.FillUnique(); + auto literal = Literal::CreateR2FromArray2DWithLayout( + input, LayoutUtil::MakeLayout(spec.layout)); XlaBuilder builder(TestName()); - auto a = builder.ConstantR2FromArray2DWithLayout( - input, LayoutUtil::MakeLayout(spec.layout)); 
+ auto a = builder.Parameter(0, literal->shape(), "p0"); builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, + client_->TransferToServer(*literal)); std::unique_ptr> expected = ReferenceUtil::Slice2D( input, spec.slice_starts, spec.slice_limits, spec.slice_strides); - ComputeAndCompareR2(&builder, *expected, {}); + ComputeAndCompareR2(&builder, *expected, {arg.get()}); } INSTANTIATE_TEST_CASE_P( @@ -453,7 +459,7 @@ class SliceR4Test : public ClientLibraryTestBase, void Run(const R4Spec& spec) { Array4D values(spec.input_dims[0], spec.input_dims[1], spec.input_dims[2], spec.input_dims[3]); - values.FillRandom(3.14f); + values.FillIota(3.14159); auto expected = ReferenceUtil::Slice4D( values, spec.slice_starts, spec.slice_limits, spec.slice_strides); XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 810cc25f1b5b1199984a3229909a70f9548c7dd2..dd7c541733634213606b5a7983b59bb1f14bf75c 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -26,6 +26,7 @@ namespace { template void PopulateWithRandomFloatingPointDataImpl(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); // Create uniform numbers between 1 and 1.125 to avoid creating denormal @@ -59,12 +60,14 @@ void PopulateWithRandomFloatingPointDataImpl(Literal* literal, template void PopulateWithRandomFloatingPointData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); PopulateWithRandomFloatingPointDataImpl(literal, engine); } template <> void PopulateWithRandomFloatingPointData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); PopulateWithRandomFloatingPointDataImpl(literal, engine); } @@ -73,6 +76,7 @@ void 
PopulateWithRandomFloatingPointData(Literal* literal, template <> void PopulateWithRandomFloatingPointData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), BF16); std::uniform_real_distribution generator(-0.9f, 1.0f); TF_CHECK_OK(literal->Populate( @@ -84,6 +88,7 @@ void PopulateWithRandomFloatingPointData(Literal* literal, template void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); std::uniform_int_distribution generator( @@ -107,7 +112,10 @@ StatusOr> MakeFakeLiteralInternal( } return Literal::MakeTupleOwned(std::move(elements)); } - std::unique_ptr literal = Literal::CreateFromShape(shape); + if (engine == nullptr) { + return Literal::CreateFromShape(shape); + } + auto literal = MakeUnique(shape); switch (shape.element_type()) { case BF16: PopulateWithRandomFloatingPointData(literal.get(), engine); @@ -201,11 +209,13 @@ std::unique_ptr MakeRandomNonwrappingSliceIndex( std::minstd_rand0* engine) { const int64 rank = ShapeUtil::Rank(input_shape); std::vector start_indices(rank); - for (int i = 0; i < rank; ++i) { - const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - - ShapeUtil::GetDimension(slice_shape, i); - std::uniform_int_distribution generator(0, upper_bound); - start_indices[i] = generator(*engine); + if (engine != nullptr) { + for (int i = 0; i < rank; ++i) { + const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - + ShapeUtil::GetDimension(slice_shape, i); + std::uniform_int_distribution generator(0, upper_bound); + start_indices[i] = generator(*engine); + } } return Literal::CreateR1(start_indices); } @@ -321,20 +331,21 @@ StatusOr> MakeConstrainedArgument( } // namespace -StatusOr> MakeFakeLiteral(const Shape& shape) { - std::minstd_rand0 engine; - return MakeFakeLiteralInternal(shape, &engine); +StatusOr> 
MakeFakeLiteral(const Shape& shape, + bool pseudo_random) { + auto engine = pseudo_random ? MakeUnique() : nullptr; + return MakeFakeLiteralInternal(shape, engine.get()); } StatusOr>> MakeFakeArguments( - HloModule* const module) { + HloModule* const module, bool pseudo_random) { TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); - std::minstd_rand0 engine; + auto engine = pseudo_random ? MakeUnique() : nullptr; std::vector> arguments(params.size()); for (int i = 0; i < params.size(); ++i) { - TF_ASSIGN_OR_RETURN( - arguments[i], MakeConstrainedArgument(*dataflow, *params[i], &engine)); + TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument( + *dataflow, *params[i], engine.get())); } return std::move(arguments); } diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index f483cdebea5c7c8a43e73ab57748a93c97bb78d7..a8689f64981569ceb7c8a712f8ece00c99e8cf2d 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -55,16 +55,28 @@ class PseudorandomGenerator { }; // Generates fake data in a literal of the given shape, or returns an error -// status if the element type is currently unhandled for fake data generation. -StatusOr> MakeFakeLiteral(const Shape& shape); +// status if the element type is currently unhandled for fake data +// generation. See below for documentation of pseudo_random. +StatusOr> MakeFakeLiteral(const Shape& shape, + bool pseudo_random = true); // Generates a vector of arguments containing fake data. The number, shape and // layout of the arguments is appropriate for given HLO module. // // Will handle special cases such as making sure that indices used for dynamic // slices are bounded, reduces that call adds use 0 as an init value, etc. 
+// +// If pseudo_random is true, the generated numbers will be generated +// deterministically in a pseudo random way unless the values are constrained to +// be e.g. init values as above. If pseudo_random is false, the returned values +// will be generated in a faster way that yields less interesting data, e.g. the +// values may all be just the same value. +// +// TODO(b/79942829): Make interesting argument generation fast enough that using +// pseudo_random does not save any noticeable amount of time so that the +// parameter can be removed. StatusOr>> MakeFakeArguments( - HloModule* const module); + HloModule* const module, bool pseudo_random = true); // Check that a given module satisfies various constraints before trying to // execute it. diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4585244ce81c14ab6d4d629bb7d208d73c82248d --- /dev/null +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -0,0 +1,124 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class TokenHloTest : public HloTestBase {}; + +// TODO(b/79770375): Compile, not just verify the HLO module when the backends +// support kGenerateToken. +XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + module->AddEntryComputation(builder.Build()); + EXPECT_IS_OK(HloVerifier().Run(module.get()).status()); +} + +XLA_TEST_F(TokenHloTest, TokenTree) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto token0 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto token1 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto token2 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + builder.AddInstruction( + HloInstruction::CreateGenerateToken({token0, token0, token1, token2})); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + module->AddEntryComputation(builder.Build()); + EXPECT_IS_OK(HloVerifier().Run(module.get()).status()); +} + +XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), 
"p0")); + builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeTokenShape(), "p1")); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("Entry parameter 1 is or contains a token shape")); +} + +XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateParameter( + 0, + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeTokenShape()}), + "param")); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("Entry parameter 0 is or contains a token shape")); +} + +XLA_TEST_F(TokenHloTest, InvalidTokenRoot) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Entry root is or contains a token shape")); +} + +XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + builder.AddInstruction(HloInstruction::CreateGenerateToken({param})); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(123))); + module->AddEntryComputation(builder.Build()); + 
+ Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr( + "Operands of token instructions must be TOKEN types")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index e2067bc1b835a946fc56801cbf227e05ef0686b4..0063e7ad415e9b6718c164f415ced6fb76cbf44a 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -175,7 +175,7 @@ XLA_TEST_F(TransferManagerTest, TransferTuple) { transfer_manager_->TransferLiteralFromDevice( stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { @@ -189,7 +189,7 @@ XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { transfer_manager_->TransferLiteralFromDevice( stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { @@ -209,7 +209,7 @@ XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { transfer_manager_->TransferLiteralFromDevice( stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValue) { @@ -224,7 +224,7 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValue) { transfer_manager_->TransferLiteralFromDevice( stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { @@ -243,7 +243,7 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { 
transfer_manager_->TransferLiteralFromDevice( stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc index 59ce23d0247b58c6aebc2b5a65453157c1ca15ff..fe1e3da7eca00e128377e6e56af877868aafa836 100644 --- a/tensorflow/compiler/xla/tests/transpose_test.cc +++ b/tensorflow/compiler/xla/tests/transpose_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" namespace xla { diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 5c287bac6a7cab5a3c2642971a5a67070ee56c72..41189231b90e842292830a932cf381af60456d4c 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" @@ -496,7 +495,7 @@ XLA_TEST_F(TupleTest, ComplexTuples) { auto sum = Literal::CreateR2({{{111, 222}, {331, 442}}, {{1011, 2022}, {3031, 4042}}, {{10011, 20022}, {30031, 40042}}}); - auto prod = Literal::CreateFromShape(sum->shape()); + auto prod = MakeUnique(sum->shape()); ASSERT_TRUE(prod->Populate( [&sum](tensorflow::gtl::ArraySlice indexes) { return sum->Get(indexes) * @@ -515,7 +514,7 @@ XLA_TEST_F(TupleTest, ComplexTuples) { class TupleHloTest : public HloTestBase {}; // Disabled on the interpreter because bitcast doesn't exist on the interpreter. -TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { +XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { const char* testcase = R"( HloModule m diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 50c8766f2e3976c7077046283ab3b3e762622fc5..c3abe22797f5eaa76ced2ad8534bd68c32983e60 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -84,6 +84,11 @@ int UnaryOpTest::inf() { return 2147483647; } +template <> +int64 UnaryOpTest::inf() { + return 0x7FFFFFFFFFFFFFFFl; +} + template <> void UnaryOpTest::AbsTestHelper() { XlaBuilder builder(TestName()); @@ -176,6 +181,7 @@ XLA_TEST_F(UnaryOpTest, SignTestR0) { XLA_TEST_F(UnaryOpTest, SignTestR1) { SignTestHelper(); + SignTestHelper(); SignTestHelper(); SignTestHelper(); } diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 7944b5132f3d11cf84488acbd920cc98c084072a..3c9a01653c67203cbc962a3d3d967142f7a2102c 100644 --- 
a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -84,8 +84,8 @@ Status ParseOneProfileOutputLine( string match_percentage = "\\d+\\.\\d\\d%"; string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)"; string match_usecs = "([0-9.]+) usec"; - string match_flops = "([^ ]+)"; - string match_trops = "([^ ]+)"; + string match_flops = "([^ ]*)"; + string match_trops = "([^ ]*)"; string match_bytes_per_sec = "([0-9.TGMKi]+)B/s"; string match_bytes_per_cycle = "([0-9.TGMKi]+)B/cycle"; diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index 6e3061b78a554f028b2ffae2e0590d91a4fe48e2..373c0d2d8d8ab05dec11e51f265d41b91e7920bf 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -30,7 +30,7 @@ limitations under the License. namespace xla { -/* static */ tensorflow::Status TextLiteralWriter::WriteToPath( +/* static */ Status TextLiteralWriter::WriteToPath( const Literal& literal, tensorflow::StringPiece path) { std::unique_ptr f; auto s = tensorflow::Env::Default()->NewWritableFile(std::string(path), &f); @@ -43,7 +43,7 @@ namespace xla { return s; } - tensorflow::Status status; + Status status; tensorflow::WritableFile* f_ptr = f.get(); literal.EachCellAsString( [f_ptr, &status](tensorflow::gtl::ArraySlice indices, diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h index 7375493f4309c9bf75fc9d724626267dff7ce5ed..0a1235b5e04675da0f412bafab6c4ecf04367787 100644 --- a/tensorflow/compiler/xla/text_literal_writer.h +++ b/tensorflow/compiler/xla/text_literal_writer.h @@ -37,8 +37,8 @@ namespace xla { // This should be readable by xla::TextLiteralReader. 
class TextLiteralWriter { public: - static tensorflow::Status WriteToPath(const Literal& literal, - tensorflow::StringPiece path); + static Status WriteToPath(const Literal& literal, + tensorflow::StringPiece path); private: TF_DISALLOW_COPY_AND_ASSIGN(TextLiteralWriter); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 78ab2dccafc37aa4f93da0b8d5b39a779ddd5db8..e4a052c8f1c0009619c3a94606f6384d04006e4e 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -36,11 +36,10 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", ], ) @@ -63,10 +62,9 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) @@ -84,12 +82,12 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:testing", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/compiler/xla/service:session_proto", + 
"//tensorflow/compiler/xla/service/gpu:infeed_manager", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", @@ -138,7 +136,7 @@ tf_cc_binary( deps = [ "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", ], ) @@ -165,12 +163,10 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:computation_tracker", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) @@ -184,12 +180,11 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) @@ -202,13 +197,12 @@ tf_cc_binary( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_graph_dumper", + "//tensorflow/compiler/xla/service:hlo_proto", 
"//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/tools/convert_computation.cc b/tensorflow/compiler/xla/tools/convert_computation.cc index fe03a6e7bdfe99877c250fe1ae22beee4c8018a2..14d01b5bfb067cc39abc4d6e0605007624b6e0ae 100644 --- a/tensorflow/compiler/xla/tools/convert_computation.cc +++ b/tensorflow/compiler/xla/tools/convert_computation.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/env.h" @@ -33,7 +33,7 @@ namespace xla { namespace tools { void RealMain(const string& mode, const string& path) { - SessionModule module; + HloSnapshot module; tensorflow::Env* env = tensorflow::Env::Default(); if (mode == "txt2bin") { TF_CHECK_OK(tensorflow::ReadTextProto(env, path, &module)); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index 21ae8583d7cd3343230dcaff7dc17456e9e3e702..befb55453777dce30af89bcaad2ffe1647097576 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -17,7 +17,7 @@ limitations under the License. // // Dumps a graphviz URL for a snapshot computation to the command line. // -// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from // ServiceInterface::SnapshotComputation to disk. // // The GraphViz URL is placed into the log stderr, whereas computation @@ -30,11 +30,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -49,10 +48,11 @@ namespace tools { void RealMain(tensorflow::gtl::ArraySlice args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { - SessionModule module; + HloSnapshot module; TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + XlaComputation computation = + client->LoadSnapshot(module).ConsumeValueOrDie(); DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); debug_options.set_xla_generate_hlo_graph(".*"); ComputationStats stats = diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index b82f1c81c84b487c1661af5267b9123da97bb107..cfb8f37487d6499b803438a135be54524fcf17d2 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -21,11 +21,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -66,16 +65,16 @@ void RealMain(tensorflow::gtl::ArraySlice args) { LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); for (char* arg : args) { - SessionModule session_module; + HloSnapshot snapshot; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, - &session_module)); - auto computation_status = client->LoadSnapshot(session_module); + &snapshot)); + auto computation_status = client->LoadSnapshot(snapshot); if (!computation_status.ok()) { fprintf(stderr, "could not load snapshot for %s: %s\n", arg, computation_status.status().ToString().c_str()); continue; } - Computation computation = computation_status.ConsumeValueOrDie(); + XlaComputation computation = computation_status.ConsumeValueOrDie(); std::unique_ptr program_shape = client->GetComputationShape(computation).ConsumeValueOrDie(); @@ -89,8 +88,7 @@ void RealMain(tensorflow::gtl::ArraySlice args) { build_options.set_device_ordinal(0); build_options.set_result_layout(program_shape->result()); StatusOr> executable = - local_service->CompileExecutable(computation.handle(), layouts, - build_options); + local_service->CompileExecutable(computation, layouts, build_options); const HloModule& module = executable.ValueOrDie()->module(); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 
05c0fdf97d27c09eb2bbb0f265b5b2a5982ca7b1..5dd5150be339846d0775880931f615b92c5b08d8 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -19,11 +19,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -40,16 +38,16 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); for (char* arg : args) { - SessionModule session_module; + HloSnapshot snapshot; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, - &session_module)); - auto computation_status = client->LoadSnapshot(session_module); + &snapshot)); + auto computation_status = client->LoadSnapshot(snapshot); if (!computation_status.ok()) { fprintf(stderr, "could not load snapshot for %s: %s\n", arg, computation_status.status().ToString().c_str()); continue; } - Computation computation = computation_status.ConsumeValueOrDie(); + XlaComputation computation = computation_status.ConsumeValueOrDie(); if (compile) { std::unique_ptr program_shape = @@ -65,8 +63,7 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { build_options.set_device_ordinal(0); build_options.set_result_layout(program_shape->result()); StatusOr> executable = - local_service->CompileExecutable(computation.handle(), layouts, - build_options); + 
local_service->CompileExecutable(computation, layouts, build_options); const HloModule& module = executable.ValueOrDie()->module(); @@ -74,13 +71,11 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { local_service->backend().platform()->Name().c_str(), module.ToString(HloPrintOptions::ShortParsable()).c_str()); } else { - const ComputationTracker& tracker = local_service->computation_tracker(); - UserComputation* user_computation = - tracker.Resolve(computation.handle()).ConsumeValueOrDie(); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); + auto config = HloModule::CreateModuleConfigFromProto(computation.proto(), + DebugOptions()) + .ConsumeValueOrDie(); std::unique_ptr module = - tracker.BuildHloModule(versioned_handle, HloModuleConfig()) + HloModule::CreateFromProto(computation.proto(), config) .ConsumeValueOrDie(); fprintf(stdout, "%s\n", diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index 51f90b07c66f7d839f587350726333b9dbe6a9f0..a5dce20456c6a2402f425ebb3d575d1bb625f839 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -28,11 +28,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -48,10 +47,11 @@ namespace tools { void RealMain(tensorflow::gtl::ArraySlice args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { - SessionModule module; + HloSnapshot module; TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + XlaComputation computation = + client->LoadSnapshot(module).ConsumeValueOrDie(); DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); debug_options.set_xla_generate_hlo_graph(".*"); debug_options.set_xla_hlo_dump_as_graphdef(true); diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD deleted file mode 100644 index 0fa4b98d0a41a1e7c681bb2302da3b752315867b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tools/parser/BUILD +++ /dev/null @@ -1,72 +0,0 @@ -# Build file for the Hlo parser. - -licenses(["notice"]) # Apache 2.0 - -package( - default_visibility = [":friends"], -) - -package_group( - name = "friends", - includes = [ - "//tensorflow/compiler/xla:friends", - ], -) - -# Filegroup used to collect source files for dependency checking. 
-filegroup( - name = "c_srcs", - data = glob([ - "**/*.cc", - "**/*.h", - ]), -) - -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -cc_library( - name = "hlo_lexer", - srcs = ["hlo_lexer.cc"], - hdrs = [ - "hlo_lexer.h", - "hlo_token.h", - ], - deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", - ], -) - -cc_library( - name = "hlo_parser", - srcs = ["hlo_parser.cc"], - hdrs = ["hlo_parser.h"], - deps = [ - ":hlo_lexer", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - ], -) - -tf_cc_test( - name = "hlo_parser_test", - size = "small", - srcs = ["hlo_parser_test.cc"], - deps = [ - ":hlo_parser", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index d8cedad65ea68ef86b94394a1accf2c08517c0b2..f7574e0b1cc95daee6d6743ba4e2e490ee87e7c6 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -17,13 +17,16 @@ limitations under the License. // // Replays computations and shows the results on the command line. // -// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from // ServiceInterface::SnapshotComputation to disk. // // Computations that require arguments can be replayed using fake data by // passing --use_fake_data on the command line. 
If the real data is available // in the proto and --use_fake_data is false, the real data is used. // +// Input can be a binary HloSnapshot proto, a binary HloProto proto, or a +// textual HLO string. +// // The output format is: // // file_path: computation_name :: type:literal_str @@ -36,14 +39,14 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/testing.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -65,136 +68,177 @@ namespace { // fields. struct Options { string fake_infeed_shape; + bool generate_fake_infeed = false; bool use_fake_data = false; bool print_result = true; int num_runs = 1; - bool xla_hlo_profile_last_run = false; }; // Invokes the given computation passing arbitrary data for every (unbound) // parameter if use_fake_data, Otherwise use recorded data if available. // -// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided; -// otherwise, no infeed is performed. 
-template -StatusOr> ReplayComputation(const ModuleT& module, - Client* client, - const Options& opts) { - static_assert(std::is_same::value || - std::is_same::value, - "Proto must be in HloSnapshot or SessionModule format"); - TF_ASSIGN_OR_RETURN(auto computation, client->LoadSnapshot(module)); - - std::vector> arguments; +// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided. +// If generate_fake_infeed is true, the required infeed shape is derived from +// the computation and then used to provide a fake infeed shape. +// +// If neither generate_fake_infeed is true nor a fake_infeed_shape is provided, +// no infeed is performed. +StatusOr ReplayComputation(const HloSnapshot& module, + LocalClient* client, const Options& opts) { + XlaComputation computation(module.hlo().hlo_module()); + + // Build the `argument_ptrs` vector, which contains ShapedBuffer*s to our + // arguments. This is a bit involved, because we may have to convert from + // GlobalData to ShapedBuffer*, and we have to manage the lifetime of all our + // objects. 
+ std::vector scoped_shaped_buffer_arguments; + std::vector> global_data_arguments; + std::vector argument_ptrs; if (opts.use_fake_data) { - arguments = MakeFakeArgumentsOrDie(computation, client); + global_data_arguments = MakeFakeArgumentsOrDie(computation, client); + for (const auto& data : global_data_arguments) { + argument_ptrs.push_back( + client->GlobalDataToShapedBuffer(data->handle(), /*device_ordinal=*/0) + .ValueOrDie()); + } } else { // use recorded data if available for (const auto& proto : module.arguments()) { TF_ASSIGN_OR_RETURN(std::unique_ptr literal, Literal::CreateFromProto(proto)); - TF_ASSIGN_OR_RETURN(std::unique_ptr data, - client->TransferToServer(*literal)); - arguments.push_back(std::move(data)); + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer data, + client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0)); + scoped_shaped_buffer_arguments.push_back(std::move(data)); + } + for (const auto& argument : scoped_shaped_buffer_arguments) { + argument_ptrs.push_back(&argument); } } + bool provide_infeed = false; + Shape infeed_shape; + if (!opts.fake_infeed_shape.empty()) { + StatusOr shape_status = + ShapeUtil::ParseShapeString(opts.fake_infeed_shape); + TF_CHECK_OK(shape_status.status()); + infeed_shape = std::move(shape_status).ValueOrDie(); + provide_infeed = true; + } else if (opts.generate_fake_infeed) { + for (const auto& comp : computation.proto().computations()) { + for (const auto& instruction : comp.instructions()) { + if (instruction.opcode() == HloOpcodeString(HloOpcode::kInfeed)) { + CHECK(!provide_infeed) + << "--generate_fake_infeed only works if the model has 0 or 1 " + "infeed ops, but this one has >= 2."; + provide_infeed = true; + infeed_shape = instruction.shape(); + LOG(INFO) << "Generating fake infeed shape for inferred shape: " + << ShapeUtil::HumanString(infeed_shape); + } + } + } + } // We only instantiate the thread pool if the user has requested that a - // concurrent infeed occur via the fake_infeed_shape. 
+ // concurrent infeed occur via the fake_infeed_shape, or when + // --generate_fake_infeed is passed and there exists an infeed operation in + // the HloSnapshot. tensorflow::gtl::optional pool; - - if (!opts.fake_infeed_shape.empty()) { + std::unique_ptr data; + if (provide_infeed) { + data = std::move(MakeFakeLiteral(infeed_shape)).ValueOrDie(); + } + auto transfer_infeed = [&data, client]() { + TF_CHECK_OK(client->TransferToInfeed(*data)); + }; + if (provide_infeed) { pool.emplace(tensorflow::Env::Default(), "infeed", /*num_threads=*/1); - pool->Schedule([opts, client]() { - StatusOr shape_status = - ShapeUtil::ParseShapeString(opts.fake_infeed_shape); - TF_CHECK_OK(shape_status.status()); - Shape shape = std::move(shape_status).ValueOrDie(); - StatusOr> data_status = MakeFakeLiteral(shape); - TF_CHECK_OK(data_status.status()); - std::unique_ptr data = std::move(data_status).ValueOrDie(); - while (true) { - TF_CHECK_OK(client->TransferToInfeed(*data)); - } + pool->Schedule([transfer_infeed]() { + // There may be several infeed buffers needed, however we don't know how + // many. If we proactively transfer too many infeed buffers, we may run + // out of memory. If we transfer too few infeed buffers, the program will + // hang. Therefore, we register a callback that is called when the infeed + // becomes empty, and in this callback we will transfer another fake + // infeed. 
+ auto infeed_manager = xla::gpu::GetOrCreateInfeedManager(); + infeed_manager->RegisterOnEmptyCallback(transfer_infeed); + transfer_infeed(); }); } - std::vector execute_arguments; - execute_arguments.reserve(arguments.size()); - for (auto& argument : arguments) { - execute_arguments.push_back(argument.get()); + std::vector argument_layouts; + for (const auto& param : computation.proto().program_shape().parameters()) { + argument_layouts.push_back(&param); } + std::unique_ptr executable = + client->Compile(computation, argument_layouts, ExecutableBuildOptions()) + .ValueOrDie(); // Run the computation num_runs times, and return the result from the last // execution. - std::unique_ptr result; + StreamExecutorMemoryAllocator allocator( + client->platform(), + {client->platform()->ExecutorForDevice(0).ValueOrDie()}); + tensorflow::gtl::optional result; for (int i = 0; i < opts.num_runs; ++i) { ExecutionProfile profile; - ExecutionOptions execution_options = CreateDefaultExecutionOptions(); - if (opts.xla_hlo_profile_last_run && i == opts.num_runs - 1) { - execution_options.mutable_debug_options()->set_xla_hlo_profile(true); - } + ExecutableRunOptions run_options; + run_options.set_execution_profile(&profile); + run_options.set_allocator(&allocator); - if (opts.print_result) { - TF_ASSIGN_OR_RETURN( - result, client->ExecuteAndTransfer(computation, execute_arguments, - &execution_options, &profile)); - } else { - // If we're not printing the result, execute the computation but don't - // bother retrieving the result. This can be a significant speedup.
- TF_RETURN_IF_ERROR(client - ->Execute(computation, execute_arguments, - &execution_options, &profile) - .status()); - } + TF_ASSIGN_OR_RETURN(result, executable->Run(argument_ptrs, run_options)); LOG(INFO) << "Execution took " << static_cast(profile.compute_time_ns()) / 1e9 << "s"; } - return std::move(result); + // Check that --num_runs > 0, otherwise *result below will fail with an + // unhelpful error (because the loop didn't run any iterations). + CHECK_GT(opts.num_runs, 0) << "--num_runs must be > 0"; + TF_ASSIGN_OR_RETURN(std::unique_ptr result_literal, + client->ShapedBufferToLiteral(*result)); + return std::move(*result_literal); } -int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { - Client* client = ClientLibrary::LocalClientOrDie(); +StatusOr ParseInputFile(const string& filename, + const Options& opts) { tensorflow::Env* env = tensorflow::Env::Default(); - int exit_status = EXIT_SUCCESS; - for (char* arg : args) { - HloSnapshot snapshot; - auto status = tensorflow::ReadBinaryProto(env, arg, &snapshot); - if (status.ok()) { - StatusOr> result_status = - ReplayComputation(snapshot, client, opts); - if (!result_status.ok()) { - fprintf(stderr, "%s: error: %s\n", arg, - result_status.status().ToString().c_str()); - exit_status = EXIT_FAILURE; - continue; - } + HloSnapshot snapshot; + if (tensorflow::ReadBinaryProto(env, filename, &snapshot).ok()) { + return snapshot; + } + CHECK(opts.use_fake_data) + << "Without --use_fake_data, you must pass an HloSnapshot -- HloProto " + "and textual HLO don't carry real data."; + fprintf(stderr, "%s: is not HloSnapshot. 
Trying HloProto.\n", + filename.c_str()); - std::unique_ptr result = result_status.ConsumeValueOrDie(); - if (result != nullptr) { - fprintf(stdout, "%s: %s :: %s:%s\n", arg, - snapshot.hlo().hlo_module().name().c_str(), - ShapeUtil::HumanString(result->shape()).c_str(), - result->ToString().c_str()); - if (snapshot.has_result()) { - std::unique_ptr literal = - Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); - fprintf(stdout, "was %s:%s\n", - ShapeUtil::HumanString(snapshot.result().shape()).c_str(), - literal->ToString().c_str()); - } - } + if (tensorflow::ReadBinaryProto(env, filename, snapshot.mutable_hlo()).ok()) { + return snapshot; + } + fprintf(stderr, "%s: is not HloProto. Trying HLO text.\n", filename.c_str()); + string contents; + TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, filename, &contents)); + StatusOr> module = ParseHloString(contents); + if (module.ok()) { + *snapshot.mutable_hlo()->mutable_hlo_module() = + module.ValueOrDie()->ToProto(); + return snapshot; + } + fprintf(stderr, "%s: is not HLO text. Nothing left to try.\n", + filename.c_str()); + return InvalidArgument("Could not parse %s.", filename.c_str()); +} +int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { + LocalClient* client = ClientLibrary::LocalClientOrDie(); + int exit_status = EXIT_SUCCESS; + for (char* arg : args) { + StatusOr maybe_snapshot = ParseInputFile(arg, opts); + if (!maybe_snapshot.ok()) { continue; } - fprintf(stderr, "%s: is not HloSnapshot: %s. 
Trying as SessionModule...\n", - arg, status.ToString().c_str()); - - SessionModule module; - TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module)); - StatusOr> result_status = - ReplayComputation(module, client, opts); + HloSnapshot snapshot = std::move(maybe_snapshot).ValueOrDie(); + StatusOr result_status = ReplayComputation(snapshot, client, opts); if (!result_status.ok()) { fprintf(stderr, "%s: error: %s\n", arg, result_status.status().ToString().c_str()); @@ -202,16 +246,17 @@ int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { continue; } - std::unique_ptr result = result_status.ConsumeValueOrDie(); - if (result != nullptr) { - fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(), - ShapeUtil::HumanString(result->shape()).c_str(), - result->ToString().c_str()); - if (module.has_result()) { + if (opts.print_result) { + Literal result = std::move(result_status).ValueOrDie(); + fprintf(stdout, "%s: %s :: %s:%s\n", arg, + snapshot.hlo().hlo_module().name().c_str(), + ShapeUtil::HumanString(result.shape()).c_str(), + result.ToString().c_str()); + if (snapshot.has_result()) { std::unique_ptr literal = - Literal::CreateFromProto(module.result()).ConsumeValueOrDie(); + Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); fprintf(stdout, "was %s:%s\n", - ShapeUtil::HumanString(module.result().shape()).c_str(), + ShapeUtil::HumanString(snapshot.result().shape()).c_str(), literal->ToString().c_str()); } } @@ -236,9 +281,9 @@ int main(int argc, char** argv) { "Number of times to run each computation"), tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape, "Shape of fake data to construct for (infinite) infeed"), - tensorflow::Flag( - "xla_hlo_profile_last_run", &opts.xla_hlo_profile_last_run, - "Pass --xla_hlo_profile the last time we run the computation."), + tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed, + "Whether a fake infeed shape should be generated " + "derived from the 
computation"), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tools/show_signature.cc b/tensorflow/compiler/xla/tools/show_signature.cc index 1f3340cbc6afa9bda8bf639d01b8185968f79a4d..4e53fafcc97ff53afc5713e7ed8ee5222fac316b 100644 --- a/tensorflow/compiler/xla/tools/show_signature.cc +++ b/tensorflow/compiler/xla/tools/show_signature.cc @@ -18,7 +18,7 @@ limitations under the License. // Shows the signature (ProgramShape) of binary snapshot proto(s) on the command // line. // -// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from // ServiceInterface::SnapshotComputation to disk. // // The output format is: @@ -31,9 +31,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -49,13 +48,14 @@ namespace tools { void RealMain(tensorflow::gtl::ArraySlice args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { - SessionModule module; + HloSnapshot module; TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + auto computation = client->LoadSnapshot(module).ConsumeValueOrDie(); std::unique_ptr shape = client->GetComputationShape(computation).ConsumeValueOrDie(); - fprintf(stdout, "%s: %s :: %s\n", arg, module.entry().name().c_str(), + fprintf(stdout, "%s: %s :: 
%s\n", arg, + module.hlo().hlo_module().name().c_str(), ShapeUtil::HumanString(*shape).c_str()); } } diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index be33bd6dd1304fa8fc6e5aed1d4c4d65bf97e692..b4f45cc972d3d397ddff8e8d9163d1fef387392f 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -218,6 +218,12 @@ Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); // Passed-varargs variant of the InvalidArgument factory above. Status InvalidArgumentV(const char* format, va_list args); +template +Status InvalidArgumentStrCat(Args&&... concat) { + return InvalidArgument( + "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); +} + template Status UnimplementedStrCat(Args&&... concat) { return Unimplemented( @@ -486,6 +492,12 @@ bool c_is_sorted(const C& c) { return std::is_sorted(std::begin(c), std::end(c)); } +template +bool c_is_sorted(const C& c, Compare&& comp) { + return std::is_sorted(std::begin(c), std::end(c), + std::forward(comp)); +} + template auto c_adjacent_find(const C& c) -> decltype(std::begin(c)) { return std::adjacent_find(std::begin(c), std::end(c)); @@ -514,12 +526,29 @@ typename std::decay::type c_accumulate(const Sequence& sequence, T&& init, std::forward(binary_op)); } +template +typename std::iterator_traits< + decltype(std::begin(std::declval()))>::difference_type +c_count_if(const C& c, Pred&& pred) { + return std::count_if(std::begin(c), std::end(c), std::forward(pred)); +} + template int64 FindIndex(const C& c, Value&& value) { auto it = c_find(c, std::forward(value)); return std::distance(c.begin(), it); } +template +void InsertAt(C* c, int64 index, Value&& value) { + c->insert(c->begin() + index, std::forward(value)); +} + +template +void EraseAt(C* c, int64 index) { + c->erase(c->begin() + index); +} + // Returns true if `x` fits in 32-bits. 
template bool IsInt32(T x) { diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index f619b8dc24038af64a27fc0565c74447ca9d09cf..6f07e4606bef015214f2c564515c8258a906205b 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -17,7 +17,6 @@ syntax = "proto3"; import "tensorflow/compiler/xla/xla_data.proto"; import "tensorflow/compiler/xla/service/hlo.proto"; -import "tensorflow/compiler/xla/service/session.proto"; package xla; @@ -226,22 +225,6 @@ message ExecutionOptions { repeated DeviceHandle device_handles = 5; } -message SnapshotComputationRequest { - ComputationHandle computation = 1; -} - -message SnapshotComputationResponse { - SessionModule module = 1; -} - -message LoadComputationSnapshotRequest { - SessionModule module = 1; -} - -message LoadComputationSnapshotResponse { - ComputationHandle computation = 1; -} - message GetDeviceHandlesRequest { int64 device_count = 1; } @@ -300,11 +283,6 @@ message ResetDeviceRequest { message ResetDeviceResponse { } -message ComputationStatsRequest { - ComputationHandle computation = 1; - DebugOptions debug_options = 2; -} - message ComputationGraphStatsRequest { HloModuleProto computation = 1; DebugOptions debug_options = 2; @@ -314,14 +292,6 @@ message ComputationStatsResponse { ComputationStats stats = 1; } -message ComputationRequest { - string name = 1; -} - -message ComputationResponse { - ComputationHandle computation = 1; -} - message CreateChannelHandleRequest { } @@ -336,24 +306,6 @@ message UnregisterRequest { message UnregisterResponse { } -message SetReturnValueRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; -} - -message SetReturnValueResponse { -} - -message ExecuteRequest { - reserved 3, 4; - - ComputationHandle computation = 1; - repeated GlobalDataHandle arguments = 2; - - // Options that affect how XLA compiles and runs code to service this request. 
- ExecutionOptions execution_options = 5; -} - message ExecuteGraphRequest { HloModuleProto computation = 1; repeated GlobalDataHandle arguments = 2; @@ -362,10 +314,6 @@ message ExecuteGraphRequest { ExecutionOptions execution_options = 3; } -message ExecuteParallelRequest { - repeated ExecuteRequest requests = 1; -} - message ExecuteGraphParallelRequest { repeated ExecuteGraphRequest requests = 1; } @@ -379,21 +327,6 @@ message ExecuteParallelResponse { repeated ExecuteResponse responses = 1; } -message ExecuteAsyncRequest { - reserved 3, 4; - - ComputationHandle computation = 1; - repeated GlobalDataHandle arguments = 2; - - // Options that affect how XLA compiles and runs code to service this request. - ExecutionOptions execution_options = 6; -} - -message ExecuteAsyncResponse { - // A handle to the execution launched asynchronously. - ExecutionHandle execution = 1; -} - message WaitForExecutionRequest { ExecutionHandle execution = 1; } @@ -403,31 +336,13 @@ message WaitForExecutionResponse { ExecutionProfile profile = 2; } -message IsConstantRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; - int64 num_parameters = 3; -} - -message IsConstantResponse { - bool is_constant = 1; -} - -message ComputeConstantRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; - Layout output_layout = 3; - repeated LiteralProto parameters = 4; -} - message ComputeConstantGraphRequest { HloModuleProto computation = 1; Layout output_layout = 2; } message ComputeConstantResponse { - // A LiteralProto is returned directly for this request, instead of a - // ComputationDataHandle. + // A LiteralProto is returned directly for this request. 
LiteralProto literal = 1; } @@ -469,14 +384,6 @@ message LoadDataResponse { int64 nanoseconds = 5; } -message SpecializeRequest { - ComputationHandle computation = 1; - repeated GlobalDataHandle arguments = 2; -} - -message SpecializeResponse { -} - message GetShapeRequest { GlobalDataHandle data = 1; } @@ -485,14 +392,6 @@ message GetShapeResponse { Shape shape = 1; } -message GetComputationShapeRequest { - ComputationHandle computation = 1; -} - -message GetComputationShapeResponse { - ProgramShape program_shape = 1; -} - message UnpackRequest { GlobalDataHandle data = 1; } diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 750d72d797b4f8680e13597ac02f6f9fa6e37bcd..0af73e8a93060f4569ddef9697b89a6fa2b8674b 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -66,11 +66,16 @@ enum PrimitiveType { // in the dimensions field. TUPLE = 13; - // An opaque type used for passing context specific data to a custom - // operation. + // An opaque type used for passing context-specific data to a custom + // operation. Shapes of this primitive type will have empty dimensions and + // tuple_shapes fields. OPAQUE = 14; - // Next = 17 + // A token type threaded between side-effecting operations. Shapes of this + // primitive type will have empty dimensions and tuple_shapes fields. + TOKEN = 17; + + // Next = 18 } // Describes the value held inside padding elements. @@ -271,12 +276,6 @@ message ExecutionProfile { int64 compute_and_transfer_time_ns = 5; } -// Handle given to a user that represents a computation that the user builds up -// before execution. -message ComputationHandle { - int64 handle = 1; -} - // Handle given to a user that represents an execution that the user launched // asynchronously on the device. message ExecutionHandle { @@ -290,13 +289,6 @@ message GlobalDataHandle { int64 handle = 1; } -// Handle given to a user that represents a data result in a computation. 
-// This is used to pass to subsequent computations that depends upon the data as -// an operand. -message ComputationDataHandle { - int64 handle = 1; -} - // Handle given to a user that represents a replicated virtual device. Each // replicated device represents N physical devices for execution where N is the // number of replicas. @@ -436,44 +428,6 @@ message GatherDimensionNumbers { int64 index_vector_dim = 4; } -// Operation requests that are all collected as a tagged union with a oneof -// field in OpRequest. - -message ConstantRequest { - LiteralProto literal = 2; -} - -message GetTupleElementRequest { - ComputationDataHandle operand = 2; - int64 index = 3; -} - -message SliceRequest { - ComputationDataHandle operand = 2; - repeated int64 start_indices = 3; - repeated int64 limit_indices = 4; - repeated int64 strides = 5; -} - -message DynamicSliceRequest { - // Operand from which to slice at dynamic 'start_indices'. - ComputationDataHandle operand = 2; - // Dynamically computed 'start_indices' for slice operation. - ComputationDataHandle start_indices = 3; - // Slice sizes for each dimension (note that indices calculations are computed - // modulo dimension sizes to avoid out-of-bound array accesses). - repeated int64 slice_sizes = 4; -} - -message DynamicUpdateSliceRequest { - // Operand on which slice 'update' is to be applied. - ComputationDataHandle operand = 2; - // The slice update to apply to 'operand'. - ComputationDataHandle update = 3; - // Dynamically computed start indices for the update slice operation. - ComputationDataHandle start_indices = 4; -} - message ConvolutionDimensionNumbers { // The number of the dimension that represents batch in the input. int64 input_batch_dimension = 7; @@ -511,13 +465,6 @@ message ConvolutionDimensionNumbers { // Next = 13 }; -message ConvolveRequest { - ComputationDataHandle lhs = 2; - ComputationDataHandle rhs = 3; // This is the filter/kernel. - Window window = 4; // Describes the filter/kernel. 
- ConvolutionDimensionNumbers dimension_numbers = 5; -} - enum FftType { FFT = 0; // Forward FFT; complex in, complex out. IFFT = 1; // Inverse FFT; complex in, complex out. @@ -526,56 +473,6 @@ enum FftType { // fft_length real out } -message FftRequest { - FftType fft_type = 1; - repeated int64 fft_length = 2; // Multivalent for higher-order FFT. - ComputationDataHandle operand = 3; -} - -message InfeedRequest { - // The shape of the data returned by reading the device's infeed buffer. - Shape shape = 2; - - // Additional infeed configuration for the backend. - bytes config = 3; -} - -message OutfeedRequest { - // The shape of the data returned by reading the device's outfeed buffer. - Shape shape = 1; - - // Operand to the Outfeed. Supports tuple. - ComputationDataHandle operand = 2; - - // Backend-specific information for how to perform the outfeed. - bytes outfeed_config = 3; -} - -message CallRequest { - ComputationHandle to_apply = 2; - repeated ComputationDataHandle operands = 3; -} - -message CustomCallRequest { - string call_target_name = 2; - repeated ComputationDataHandle operands = 3; - Shape shape = 4; -} - -message HostComputeRequest { - // Operand to the HostCompute. Supports tuple. - repeated ComputationDataHandle operands = 1; - - // Name used to identify HostSend/Recv channels. - string channel_name = 2; - - // Cost estimate in nanoseconds. - int64 cost_estimate_ns = 3; - - // The shape of any data returned by host. - Shape shape = 4; -} - message DotDimensionNumbers { // The dimension numbers that represent the 'lhs' contracting dimensions. 
repeated int64 lhs_contracting_dimensions = 1; @@ -587,291 +484,6 @@ message DotDimensionNumbers { repeated int64 rhs_batch_dimensions = 4; }; -message DotRequest { - ComputationDataHandle lhs = 2; - ComputationDataHandle rhs = 3; - DotDimensionNumbers dimension_numbers = 4; -} - -message MapRequest { - repeated ComputationDataHandle operands = 2; - ComputationHandle to_apply = 3; - repeated ComputationDataHandle static_operands = 4; - // The dimensions over which to map. - // Example mapping a Dot operation along the batch dimension 0: - // operand0.shape = [2, 2, 2], operand1.shape = [2,2,3] - // Map({operand0, operand1}, Dot, {0}) - repeated int64 dimensions = 5; -} - -message ReduceRequest { - // Operand to the reduction. - ComputationDataHandle operand = 2; - - // Initial value for the reduction. This must be consistent with the result - // shape of to_apply. - ComputationDataHandle init_value = 3; - - // The dimensions to reduce over. - repeated int64 dimensions = 4; - - // The computation to apply in the reduction. 
- ComputationHandle to_apply = 5; -} - -message ReduceWindowRequest { - ComputationDataHandle operand = 2; - ComputationDataHandle init_value = 3; - Window window = 4; - ComputationHandle to_apply = 5; -} - -message BatchNormTrainingRequest { - ComputationDataHandle operand = 1; - ComputationDataHandle scale = 2; - ComputationDataHandle offset = 3; - float epsilon = 4; - int64 feature_index = 5; -} - -message BatchNormInferenceRequest { - ComputationDataHandle operand = 1; - ComputationDataHandle scale = 2; - ComputationDataHandle offset = 3; - ComputationDataHandle mean = 4; - ComputationDataHandle variance = 5; - float epsilon = 6; - int64 feature_index = 7; -} - -message BatchNormGradRequest { - ComputationDataHandle operand = 1; - ComputationDataHandle scale = 2; - ComputationDataHandle mean = 3; - ComputationDataHandle variance = 4; - ComputationDataHandle grad_output = 5; - float epsilon = 6; - int64 feature_index = 7; -} - -message CrossReplicaSumRequest { - ComputationDataHandle operand = 2; -} - -message SelectAndScatterRequest { - // Operand array on which the windows slide. - ComputationDataHandle operand = 2; - - // Source array for the data to scatter. - ComputationDataHandle source = 3; - - // Initial scalar value for each element in the output. - ComputationDataHandle init_value = 4; - - // Window configuration. - Window window = 5; - - // Binary function used to select an element from each window. - ComputationHandle select = 6; - - // Binary function used to combine each scattered value from source with the - // current output value at the selected location. 
- ComputationHandle scatter = 7; -} - -message ReverseRequest { - ComputationDataHandle operand = 2; - repeated int64 dimensions = 3; -} - -message BroadcastRequest { - ComputationDataHandle operand = 2; - repeated int64 broadcast_sizes = 3; -} - -message PadRequest { - ComputationDataHandle operand = 2; - ComputationDataHandle padding_value = 3; - PaddingConfig padding_config = 4; -} - -message ReshapeRequest { - ComputationDataHandle operand = 2; - - // The dimension order for collapse (from fastest-changing to slowest). - repeated int64 dimensions = 3; - - // The new dimension sizes (from dimension 0 to n-1). - repeated int64 new_sizes = 4; -} - -message TransposeRequest { - ComputationDataHandle operand = 2; - - // The permutation of the operand's dimensions (in the range 0 to n-1). - repeated int64 dimensions = 3; -} - -message ParameterRequest { - Shape shape = 2; - int64 parameter = 3; - string name = 4; -} - -message GetLocalShapeRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; -} - -message GetLocalShapeResponse { - Shape shape = 1; -} - -message TraceRequest { - string tag = 2; - ComputationDataHandle operand = 3; -} - -message ConvertRequest { - ComputationDataHandle operand = 2; - PrimitiveType new_element_type = 3; -} - -message ConcatenateRequest { - repeated ComputationDataHandle operands = 2; - // The dimension in which we concatenate; e.g. if you had dimension arrays of - // [4, 1] and [5, 1], you'd concatenate in dimension 0 to produce a [9, 1]. - // Attempting to concatenate those in dimension 1 would produce an error, as - // 4 != 5 (and there is no ragged array support). 
- int64 dimension = 3; -} - -message ConditionalRequest { - ComputationDataHandle predicate = 2; - ComputationDataHandle true_operand = 3; - ComputationHandle true_computation = 4; - ComputationDataHandle false_operand = 5; - ComputationHandle false_computation = 6; -} - -message WhileRequest { - ComputationHandle condition = 2; - ComputationHandle body = 3; - ComputationDataHandle init = 4; -} - -enum UnaryOperation { - UNOP_INVALID = 0; - - // Elementwise, logical negation on booleans and bitwise negation on ints. - UNOP_NOT = 1; - - // Elementwise, computes e^x. - UNOP_EXP = 2; - - // Elementwise, computes -x. - UNOP_NEGATE = 3; - - // Puts the elements in the operand into sorted order. - UNOP_SORT = 4; - - // Elementwise, computes tanh(x). - UNOP_TANH = 5; - - // Elementwise, computes the natural logarithm of x. - UNOP_LOG = 6; - - // Elementwise, computes the floor of x. - UNOP_FLOOR = 7; - - // Elementwise, computes the ceil of x. - UNOP_CEIL = 8; - - // Elementwise, computes the abs of x. - UNOP_ABS = 9; - - // Elementwise, computes the sign of x. - UNOP_SIGN = 10; - - // Elementwise, tests if values are finite (not NaN or inf) - UNOP_IS_FINITE = 11; - - // Elementwise, computes the cosine of x. - UNOP_COS = 12; - - // Elementwise, computes the sine of x. - UNOP_SIN = 13; - - // Elementwise, rounds x to nearest integral value, rounding half-way cases - // away from zero. - UNOP_ROUND_NEAREST_AFZ = 14; - - // Elementwise, extract real component of complex x. - UNOP_REAL = 15; - - // Elementwise, extract real component of complex x. - UNOP_IMAG = 16; - - // Elementwise, computes clz(x). - UNOP_CLZ = 17; -} - -message UnaryOpRequest { - UnaryOperation unop = 2; - ComputationDataHandle operand = 3; -} - -enum BinaryOperation { - BINOP_INVALID = 0; - - // Arithmetic operations. - BINOP_ADD = 1; - BINOP_DIV = 2; - BINOP_MUL = 3; - BINOP_SUB = 4; - - // Comparison operators. 
- BINOP_EQ = 5; - BINOP_GE = 6; - BINOP_GT = 7; - BINOP_LE = 8; - BINOP_LT = 9; - BINOP_NE = 10; - - // Element-wise maximum. - BINOP_MAX = 14; - - // Element-wise minimum. - BINOP_MIN = 15; - - // Raises the left-hand-side to the right-hand-side power. - BINOP_POW = 16; - - // Remainder operation. - BINOP_REM = 17; - - // Element-wise, logical operators on booleans and bitwise operators on ints. - BINOP_AND = 18; - BINOP_OR = 19; - - BINOP_SHIFT_LEFT = 20; - BINOP_SHIFT_RIGHT_ARITHMETIC = 21; - BINOP_SHIFT_RIGHT_LOGICAL = 22; - - // Complex from real, imag. - BINOP_COMPLEX = 23; - - // Computes the 4-quadrant arctangent of the y, x input arguments. - BINOP_ATAN2 = 24; -} - -message BinaryOpRequest { - BinaryOperation binop = 2; - ComputationDataHandle lhs = 3; - ComputationDataHandle rhs = 4; - repeated int64 broadcast_dimensions = 5; -} - enum RandomDistribution { RNG_INVALID = 0; @@ -886,67 +498,6 @@ enum RandomDistribution { // Next: 4 } -message RngRequest { - RandomDistribution distribution = 2; - repeated ComputationDataHandle parameter = 3; - Shape shape = 4; -} - -enum TernaryOperation { - TRIOP_INVALID = 0; - - // Given a predicate and two operands, selects operand0 if the predicate is - // true and operand1 if the predicate is false. - TRIOP_SELECT = 1; - - // Given a min, max and an operand returns the operand if between min and max, - // else returns min if operand is less than min or max if operand is greater - // than max. - TRIOP_CLAMP = 3; -} - -message TernaryOpRequest { - TernaryOperation triop = 2; - ComputationDataHandle lhs = 3; - ComputationDataHandle rhs = 4; - ComputationDataHandle ehs = 5; -} - -enum VariadicOperation { - VAROP_INVALID = 0; - - // Creates a tuple from its operands. 
- VAROP_TUPLE = 1; -} - -message VariadicOpRequest { - VariadicOperation varop = 2; - repeated ComputationDataHandle operands = 3; -} - -message ReducePrecisionRequest { - ComputationDataHandle operand = 1; - int32 exponent_bits = 2; - int32 mantissa_bits = 3; -} - -message SendRequest { - ComputationDataHandle operand = 1; - ChannelHandle channel_handle = 2; -} - -message RecvRequest { - Shape shape = 1; - ChannelHandle channel_handle = 2; -} - -message GatherRequest { - ComputationDataHandle input = 1; - ComputationDataHandle gather_indices = 2; - GatherDimensionNumbers dimension_numbers = 3; - repeated int64 window_bounds = 4; -} - message OpSharding { enum Type { // This sharding is replicated across all devices (implies maximal, @@ -977,59 +528,3 @@ message OpSharding { // to. repeated OpSharding tuple_shardings = 5; } - -message OpRequest { - ComputationHandle computation = 1; - OpMetadata metadata = 33; - OpSharding sharding = 40; - - oneof op { - BinaryOpRequest binary_op_request = 2; - BroadcastRequest broadcast_request = 3; - CallRequest call_request = 4; - ConcatenateRequest concatenate_request = 5; - ConstantRequest constant_request = 6; - ConvertRequest convert_request = 7; - ConvolveRequest convolve_request = 8; - CrossReplicaSumRequest cross_replica_sum_request = 9; - CustomCallRequest custom_call_request = 10; - DotRequest dot_request = 43; - DynamicSliceRequest dynamic_slice_request = 11; - DynamicUpdateSliceRequest dynamic_update_slice_request = 12; - GetTupleElementRequest get_tuple_element_request = 13; - InfeedRequest infeed_request = 14; - MapRequest map_request = 15; - PadRequest pad_request = 16; - ParameterRequest parameter_request = 17; - ReducePrecisionRequest reduce_precision_request = 36; - ReduceRequest reduce_request = 18; - ReduceWindowRequest reduce_window_request = 19; - ReshapeRequest reshape_request = 20; - ReverseRequest reverse_request = 21; - RngRequest rng_request = 22; - SelectAndScatterRequest select_and_scatter_request = 
23; - SliceRequest slice_request = 24; - TernaryOpRequest ternary_op_request = 25; - TraceRequest trace_request = 26; - TransposeRequest transpose_request = 34; - UnaryOpRequest unary_op_request = 27; - VariadicOpRequest variadic_op_request = 28; - WhileRequest while_request = 29; - SendRequest send_request = 30; - RecvRequest recv_request = 31; - OutfeedRequest outfeed_request = 32; - BatchNormTrainingRequest batch_norm_training_request = 35; - BatchNormGradRequest batch_norm_grad_request = 37; - BatchNormInferenceRequest batch_norm_inference_request = 38; - FftRequest fft_request = 41; - ConvertRequest bitcast_convert_request = 42; - ConditionalRequest conditional_request = 44; - HostComputeRequest host_compute_request = 45; - GatherRequest gather_request = 46; - // Next: 47 - } -} - -message OpResponse { - ComputationDataHandle output = 1; -} diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index abdbdb4cd22ff38a0fae89af10c600a178d9a3d4..50b1ae5cc3cba2d6ac89c4415a3419ffdf7aec93 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -31,13 +31,15 @@ py_library( "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/coder:coder_py", "//tensorflow/contrib/compiler:compiler_py", + "//tensorflow/contrib/autograph", "//tensorflow/contrib/constrained_optimization", + "//tensorflow/contrib/control_flow", "//tensorflow/contrib/copy_graph:copy_graph_py", "//tensorflow/contrib/crf:crf_py", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py", "//tensorflow/contrib/data", - "//tensorflow/contrib/distribute:distribute", "//tensorflow/contrib/deprecated:deprecated_py", + "//tensorflow/contrib/distribute:distribute", "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/eager/python:tfe", "//tensorflow/contrib/estimator:estimator_py", @@ -71,6 +73,7 @@ py_library( "//tensorflow/contrib/memory_stats:memory_stats_py", "//tensorflow/contrib/meta_graph_transform", 
"//tensorflow/contrib/metrics:metrics_py", + "//tensorflow/contrib/mixed_precision:mixed_precision", "//tensorflow/contrib/model_pruning", "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py", @@ -82,7 +85,6 @@ py_library( "//tensorflow/contrib/proto", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantize:quantize_graph", - "//tensorflow/contrib/autograph", "//tensorflow/contrib/receptive_field:receptive_field_py", "//tensorflow/contrib/recurrent:recurrent_py", "//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 9f5459f41da3e5a13286f7002e4b519978bc189b..ad8c40395c2cdcc5e4288e04bb2115bd3627cdc9 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -30,6 +30,7 @@ from tensorflow.contrib import cluster_resolver from tensorflow.contrib import coder from tensorflow.contrib import compiler from tensorflow.contrib import constrained_optimization +from tensorflow.contrib import control_flow from tensorflow.contrib import copy_graph from tensorflow.contrib import crf from tensorflow.contrib import cudnn_rnn @@ -60,6 +61,7 @@ from tensorflow.contrib import lookup from tensorflow.contrib import losses from tensorflow.contrib import memory_stats from tensorflow.contrib import metrics +from tensorflow.contrib import mixed_precision from tensorflow.contrib import model_pruning from tensorflow.contrib import nccl from tensorflow.contrib import nn diff --git a/tensorflow/contrib/all_reduce/BUILD b/tensorflow/contrib/all_reduce/BUILD index 62d1b1cf079d04d50e4899cfd9ba1d405ee1efb9..881808a98bfd688c2efaa8beb5b8f11a2527fee8 100644 --- a/tensorflow/contrib/all_reduce/BUILD +++ b/tensorflow/contrib/all_reduce/BUILD @@ -11,6 +11,16 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "tf_py_test") +py_library( + name = "all_reduce_py", + srcs = ["__init__.py"], + srcs_version = 
"PY2AND3", + deps = [ + ":all_reduce", + "//tensorflow/python:util", + ], +) + py_library( name = "all_reduce", srcs = [ diff --git a/tensorflow/python/keras/applications/densenet/__init__.py b/tensorflow/contrib/all_reduce/__init__.py similarity index 54% rename from tensorflow/python/keras/applications/densenet/__init__.py rename to tensorflow/contrib/all_reduce/__init__.py index 6b8ea83920733a3a442171616ab460ffaf831521..f9824f4cfbf83d9b001a58cafe582226e96c076f 100644 --- a/tensorflow/python/keras/applications/densenet/__init__.py +++ b/tensorflow/contrib/all_reduce/__init__.py @@ -12,18 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""DenseNet Keras applications.""" +"""All-reduce implementations.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.densenet import decode_predictions -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet121 -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet169 -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet201 -from tensorflow.python.keras._impl.keras.applications.densenet import preprocess_input +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.all_reduce.python.all_reduce import * -del absolute_import -del division -del print_function +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + 'build_ring_all_reduce', + 'build_recursive_hd_all_reduce', + 'build_shuffle_all_reduce', + 'build_nccl_all_reduce', + 'build_nccl_then_ring', + 'build_nccl_then_recursive_hd', + 'build_nccl_then_shuffle', + 'build_shuffle_then_ring', + 'build_shuffle_then_shuffle' 
+] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/android/BUILD b/tensorflow/contrib/android/BUILD index 60306ebdc6cddb04e8807bfd495fa92a56e55ecd..c10179ba8b290b6209f5567d6323df4bcf711585 100644 --- a/tensorflow/contrib/android/BUILD +++ b/tensorflow/contrib/android/BUILD @@ -72,7 +72,7 @@ cc_binary( "-s", "-Wl,--gc-sections", "-Wl,--version-script", # This line must be directly followed by LINKER_SCRIPT. - LINKER_SCRIPT, + "$(location {})".format(LINKER_SCRIPT), ]), linkshared = 1, linkstatic = 1, diff --git a/tensorflow/contrib/android/jni/run_stats_jni.cc b/tensorflow/contrib/android/jni/run_stats_jni.cc index 707853b59befc2625145ad96952fbf9f66d62b43..30de7b59af79cb36ee266a15bb6e668c2e3f628a 100644 --- a/tensorflow/contrib/android/jni/run_stats_jni.cc +++ b/tensorflow/contrib/android/jni/run_stats_jni.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/contrib/android/jni/run_stats_jni.h" #include + #include #include "tensorflow/core/protobuf/config.pb.h" @@ -73,7 +74,8 @@ JNIEXPORT jstring RUN_STATS_METHOD(summary)(JNIEnv* env, jclass clazz, StatSummarizer* s = requireHandle(env, handle); if (s == nullptr) return nullptr; std::stringstream ret; - ret << s->GetStatsByMetric("Top 10 CPU", StatSummarizer::BY_TIME, 10) + ret << s->GetStatsByMetric("Top 10 CPU", tensorflow::StatsCalculator::BY_TIME, + 10) << s->GetStatsByNodeType() << s->ShortSummary(); return env->NewStringUTF(ret.str().c_str()); } diff --git a/tensorflow/contrib/autograph/CONTRIBUTING.md b/tensorflow/contrib/autograph/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..a4aec8c74a9ad1418072471a5d3cde8c3b968a38 --- /dev/null +++ b/tensorflow/contrib/autograph/CONTRIBUTING.md @@ -0,0 +1,48 @@ +# How to Contribute + +We'd love to have your patches and contributions! Here are some guidelines. 
In general, we follow the [TensorFlow contributing guidelines](../../CONTRIBUTING.md), but have some [AutoGraph-specific style guidelines](STYLE_GUIDE.md). More details below. + +## TensorFlow Code of Conduct +Please review and follow the [TensorFlow Code of Conduct](../../CODE_OF_CONDUCT.md). + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution; +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult [GitHub +Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. + +After a pull request is approved, we merge it. Note our merging process differs +from GitHub in that we pull and submit the change into an internal version +control system. This system automatically pushes a git commit to the GitHub +repository (with credit to the original author) and closes the pull request. + +## Style + +See the [AutoGraph style guide](STYLE_GUIDE.md). + +## Unit tests + +Please include unit tests when contributing new features ([example here](converters/continue_statements_test.py)), as they help to a) prove that your code works correctly, and b) guard against future breaking +changes to lower the maintenance cost. +It's also helpful to check that any +changes you propose do not break existing unit tests. You can run tests using the command, + +```shell +bazel test --config=opt --copt=-O3 --copt=-march=native \ + //tensorflow/contrib/autograph/... 
+``` + +from the root of the `tensorflow` repository. For more details see the [main TensorFlow Contributing File](../../CONTRIBUTING.md) diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md index 0ba99c396fc1c8ee1e12fbb4fe0293ee52ed9bc9..674859bed4ec157d5d5b33b6fc015c930e54b392 100644 --- a/tensorflow/contrib/autograph/README.md +++ b/tensorflow/contrib/autograph/README.md @@ -1,6 +1,6 @@ # AutoGraph -IMPORTANT: AutoGraph is pre-alpha, under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! +IMPORTANT: AutoGraph is alpha software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)). AutoGraph is a Python to TensorFlow compiler. diff --git a/tensorflow/contrib/autograph/STYLE_GUIDE.md b/tensorflow/contrib/autograph/STYLE_GUIDE.md new file mode 100644 index 0000000000000000000000000000000000000000..866e5f583a34570dfddc733f57561ed1d2b7c5bf --- /dev/null +++ b/tensorflow/contrib/autograph/STYLE_GUIDE.md @@ -0,0 +1,75 @@ +# AutoGraph Style Guide + +This page contains style decisions that developers should follow when +contributing code to AutoGraph. + +## TensorFlow Style + +Follow the [TensorFlow style +guide](https://www.tensorflow.org/community/style_guide), the [documentation +guide](https://www.tensorflow.org/community/documentation) and the +[Google Python style guide](https://google.github.io/styleguide/pyguide.html). + +Naming conventions: + +1. The name is TensorFlow, not Tensorflow. +2. The name is AutoGraph, not Autograph. + +## AutoGraph Style + +Below are AutoGraph-specific conventions. In the event of conflict, +it supersedes all previous conventions. + +1. __Citations in Docstrings.__ Write a `#### References` subsection at the + bottom of any docstring with citations.
Use ICLR’s bibliography style to + write references; for example, order entries by the first author's last + name. Add a link to the paper if the publication is open source (ideally, + arXiv). + + Write in-paragraph citations in general, e.g., [(Tran and Blei, 2018)][1]. + Write in-text citations when the citation is a noun, e.g., [Tran and Blei + (2018)][1]. Write citations with more than two authors using et al., e.g., + [(Tran et al., 2018)][1]. Separate multiple citations with semicolon, e.g., + ([Tran and Blei, 2018][1]; [Gelman and Rubin, 1992][2]). + + Examples: + + ```none + #### References + + # technical report + [1]: Tony Finch. Incremental calculation of weighted mean and variance. + _Technical Report_, 2009. + http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf + + # journal + [2]: Andrew Gelman and Donald B. Rubin. Inference from Iterative Simulation + Using Multiple Sequences. _Statistical Science_, 7(4):457-472, 1992. + + # arXiv preprint + # use "et al." for papers with too many authors to maintain + [3]: Aaron van den Oord et al. Parallel WaveNet: Fast High-Fidelity Speech + Synthesis. _arXiv preprint arXiv:1711.10433_, 2017. + https://arxiv.org/abs/1711.10433 + + # conference + [4]: Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, and Roger Grosse. + Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches. In _International Conference on Learning + Representations_, 2018. + https://arxiv.org/abs/1803.04386 + ``` + +2. Avoid LaTeX in docstrings. + + * It is not rendered in many (if not most) editors and can be hard to read + for both LaTeX experts and non-experts. + +3. Write docstring and comment math using ASCII friendly notation; python using + operators. E.g., `x**2` better than `x^2`, `x[i, j]` better than `x_{i,j}`, + `sum{ f(x[i]) : i=1...n }` better than `\sum_{i=1}^n f(x_i)` `int{sin(x) dx: + x in [0, 2 pi]}` better than `\int_0^{2\pi} sin(x) dx`. 
+ + * The more we stick to python style, the more someone can + copy/paste/execute. + * Python style is usually easier to read as ASCII. diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py index 3386c4eca4b93e850f6fe3c6239d29c61d787ece..dbdbad8f4c91c725294baa36acebbaf5b5e8cf5c 100644 --- a/tensorflow/contrib/autograph/__init__.py +++ b/tensorflow/contrib/autograph/__init__.py @@ -23,18 +23,37 @@ from __future__ import print_function # TODO(mdan): Bring only the relevant symbols to the top level. from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph import operators from tensorflow.contrib.autograph.impl.api import convert from tensorflow.contrib.autograph.impl.api import converted_call from tensorflow.contrib.autograph.impl.api import do_not_convert from tensorflow.contrib.autograph.impl.api import RunMode from tensorflow.contrib.autograph.impl.api import to_code from tensorflow.contrib.autograph.impl.api import to_graph +from tensorflow.contrib.autograph.impl.directives import set_element_type +from tensorflow.contrib.autograph.impl.directives import set_loop_options +from tensorflow.contrib.autograph.impl.special_functions import stack from tensorflow.contrib.autograph.pyct.transformer import AutographParseError from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'utils', 'convert', 'converted_call', 'do_not_convert', 'RunMode', - 'to_code', 'to_graph', 'AutographParseError' + # Main API + 'RunMode', + 'convert', + 'converted_call', + 'do_not_convert', + 'to_code', + 'to_graph', + # Overloaded operators + 'operators', + # Special functions and directives + 'set_element_type', + 'set_loop_options', + 'stack', + # Exceptions + 'AutographParseError', + # Utilities: to be removed + 'utils', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/autograph/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD index 
8f9bffa55e44e4942bb3845945b3d440c7957cc9..284ad84be566199adaaa1ab641d37528ae4dfd2d 100644 --- a/tensorflow/contrib/autograph/converters/BUILD +++ b/tensorflow/contrib/autograph/converters/BUILD @@ -31,6 +31,7 @@ py_library( "name_scopes.py", "side_effect_guards.py", "single_return.py", + "slices.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], @@ -208,3 +209,14 @@ py_test( "//tensorflow/python:client_testlib", ], ) + +py_test( + name = "slices_test", + srcs = ["slices_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":test_lib", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py index 1be1c96dd31bf05b746fae6a2b02774e20ca0c4f..775d92c1d9f8bc35d1eda62f3f3ef7ee43414779 100644 --- a/tensorflow/contrib/autograph/converters/break_statements.py +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -32,14 +32,6 @@ CONTROL_VAR_NAME = 'control_var_name' class BreakStatementTransformer(transformer.Base): """Canonicalizes break statements into additional conditionals.""" - def _track_body(self, nodes, break_var): - self.enter_local_scope() - self.set_local(CONTROL_VAR_NAME, break_var) - nodes = self.visit_block(nodes) - break_used = self.get_local(BREAK_USED, False) - self.exit_local_scope() - return nodes, break_used - def visit_Break(self, node): self.set_local(BREAK_USED, True) var_name = self.get_local(CONTROL_VAR_NAME) @@ -54,6 +46,7 @@ class BreakStatementTransformer(transformer.Base): """Prevents the block from executing if var_name is set.""" if not block: return block + template = """ if not var_name: block @@ -64,9 +57,17 @@ class BreakStatementTransformer(transformer.Base): block=block) return node + def _track_body(self, nodes, break_var): + self.enter_local_scope() + self.set_local(CONTROL_VAR_NAME, break_var) + nodes = self.visit_block(nodes) + 
break_used = self.get_local(BREAK_USED, False) + self.exit_local_scope() + return nodes, break_used + def visit_While(self, node): scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - break_var = self.context.namer.new_symbol('break__', scope.referenced) + break_var = self.context.namer.new_symbol('break_', scope.referenced) node.test = self.visit(node.test) node.body, break_used = self._track_body(node.body, break_var) @@ -74,6 +75,10 @@ class BreakStatementTransformer(transformer.Base): node.orelse = self.visit_block(node.orelse) if break_used: + # Python's else clause only triggers if the loop exited cleanly (e.g. + # break did not trigger). + guarded_orelse = self._guard_if_present(node.orelse, break_var) + template = """ var_name = False while test and not var_name: @@ -81,20 +86,18 @@ class BreakStatementTransformer(transformer.Base): else: orelse """ - # Python's else clause only triggers if the loop exited cleanly (e.g. - # break did not trigger). node = templates.replace( template, var_name=break_var, test=node.test, body=node.body, - orelse=self._guard_if_present(node.orelse, break_var)) + orelse=guarded_orelse) return node def visit_For(self, node): scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - break_var = self.context.namer.new_symbol('break__', scope.referenced) + break_var = self.context.namer.new_symbol('break_', scope.referenced) node.target = self.visit(node.target) node.iter = self.visit(node.iter) @@ -103,19 +106,32 @@ class BreakStatementTransformer(transformer.Base): node.orelse = self.visit_block(node.orelse) if break_used: - node.orelse = self._guard_if_present(node.orelse, break_var) + # Python's else clause only triggers if the loop exited cleanly (e.g. + # break did not trigger). + guarded_orelse = self._guard_if_present(node.orelse, break_var) + extra_test = templates.replace_as_expression( + 'not var_name', var_name=break_var) + + # The extra test is hidden in the AST, which will confuse the static + # analysis. 
To mitigate that, we insert a no-op statement that ensures + # the control variable is marked as used. + # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name) template = """ var_name = False - for_stmt + for target in iter_: + (var_name,) + body + else: + orelse """ - # Python's else clause only triggers if the loop exited cleanly (e.g. - # break did not trigger). node = templates.replace( template, var_name=break_var, - for_stmt=node) - extra_test = templates.replace_as_expression( - 'not var_name', var_name=break_var) + iter_=node.iter, + target=node.target, + body=node.body, + orelse=guarded_orelse) + anno.setanno(node[1], 'extra_test', extra_test) return node diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py index 317711a866f731de1b497295a2752dee0eb544f5..231e4ee35a72f51845a476d9f605986ac73b4676 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions.py @@ -31,9 +31,6 @@ class BuiltinFunctionTransformer(transformer.Base): TF equivalent, like `len`. """ - def __init__(self, context): - super(BuiltinFunctionTransformer, self).__init__(context) - def _convert_builtin(self, node): template = """ ag__.utils.dynamic_builtin(func, args) @@ -51,7 +48,7 @@ class BuiltinFunctionTransformer(transformer.Base): # TODO(mdan): This won't work if the function was hidden. # TODO(mdan): Rely on the live_val and use inspect_utils.is_builtin instead. if (isinstance(node.func, gast.Name) and - node.func.id in ('len', 'range', 'xrange')): + node.func.id in ('len', 'range', 'xrange', 'float', 'int')): return self._convert_builtin(node) # Print needs to be handled separately because it can be read as statement. 
if isinstance(node.func, gast.Name) and node.func.id == 'print': diff --git a/tensorflow/contrib/autograph/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py index 554f0471d44d54194c45c3855b1483796ae65a6a..b6ecdcb7809b1ad7e7461324cb6a110ef4180609 100644 --- a/tensorflow/contrib/autograph/converters/call_trees.py +++ b/tensorflow/contrib/autograph/converters/call_trees.py @@ -292,15 +292,25 @@ class CallTreeTransformer(transformer.Base): raise NotImplementedError( 'py_func with return values (unknown function)') else: + if anno.hasanno(node.func, anno.Basic.QN): + # Special-case a few builtins that otherwise go undetected. This + # normally doesn't pose a problem, but the dict built-in doesn't + # work with inspect.getargspec which is required for dynamic functions. + # Note: expecting this is resilient to aliasing (e.g. + # dict = an_evil_dict), because in those cases the regular mechanisms + # process a simple user function. + qn = anno.getanno(node.func, anno.Basic.QN) + # Add items to this list as needed. + if str(qn) in ('dict',): + return node + if ast_util.matches(node, 'super(_)'): # super() calls are preserved. The class conversion mechanism will # ensure that they return the correct value. - pass - elif self.context.recursive: + return node + + if self.context.recursive: node = self._insert_dynamic_conversion(node) - else: - # Unresolved functions are allowed in non-recursive mode. 
- pass return node diff --git a/tensorflow/contrib/autograph/converters/continue_statements.py b/tensorflow/contrib/autograph/converters/continue_statements.py index 4299a8a9d59715d032222c47794bbb4393f34ce6..0417817a77e706fc0ce805f7391bea600f5fbb2d 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements.py +++ b/tensorflow/contrib/autograph/converters/continue_statements.py @@ -24,103 +24,115 @@ from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno -class ContinueCanonicalizationTransformer(transformer.Base): - """Canonicalizes continue statements into additional conditionals.""" +# Tags for local state. +CONTROL_VAR_NAME = 'control_var_name' +CONTINUE_USED = 'continue_used' +GUARD_CREATED = 'guard_created' +CREATE_GUARD_NEXT = 'create_guard_next' - def __init__(self, context): - super(ContinueCanonicalizationTransformer, self).__init__(context) - # This is a stack structure, to correctly process nested loops. 
- self.continuation_uses = [] - def _create_continuation_check(self): - template = """ - if not var_name: - pass - """ - cond, = templates.replace(template, var_name=self.continuation_uses[-1][1]) - cond.body = [] - return cond +class ContinueCanonicalizationTransformer(transformer.Base): + """Canonicalizes continue statements into additional conditionals.""" - def _create_continuation_trigger(self): + def visit_Continue(self, node): + self.set_local(CONTINUE_USED, True) template = """ var_name = True """ - assign, = templates.replace( - template, var_name=self.continuation_uses[-1][1]) - return assign - - def _create_continuation_init(self): - template = """ - var_name = False - """ - assign, = templates.replace( - template, var_name=self.continuation_uses[-1][1]) - return assign - - def _visit_and_reindent_if_necessary(self, nodes): - reorganized_nodes = [] - current_dest = reorganized_nodes - continue_used_in_block = False - for i, n in enumerate(nodes): - # TODO(mdan): This could be optimized if control structures are simple. - self.continuation_uses[-1][0] = False - n = self.visit(n) - current_dest.append(n) - if self.continuation_uses[-1][0]: - continue_used_in_block = True - if i < len(nodes) - 1: # Last statement in block needs no protection. 
- cond = self._create_continuation_check() - current_dest.append(cond) - current_dest = cond.body - self.continuation_uses[-1][0] = continue_used_in_block - return reorganized_nodes - - def _process_loop_block(self, block, scope): - cont_var = self.context.namer.new_symbol('cont_requested', scope.referenced) - self.continuation_uses.append([False, cont_var]) - block = self._visit_and_reindent_if_necessary(block) - if self.continuation_uses[-1][0]: - block.insert(0, self._create_continuation_init()) - self.continuation_uses.pop() - return block + return templates.replace( + template, var_name=self.get_local(CONTROL_VAR_NAME)) + + def _postprocess_statement(self, node): + # Example of how the state machine below works: + # + # 1| stmt # State: CONTINUE_USED = False + # | # Action: none + # 2| if cond: + # 3| continue # State: CONTINUE_USED = True, + # | # GUARD_CREATED = False, + # | # CREATE_GUARD_NEXT = False + # | # Action: set CREATE_GUARD_NEXT = True + # 4| stmt # State: CONTINUE_USED = True, + # | # GUARD_CREATED = False, + # | # CREATE_GUARD_NEXT = True + # | # Action: create `if not continue_used`, + # | # set GUARD_CREATED = True + # 5| stmt # State: CONTINUE_USED = True, GUARD_CREATED = True + # | # Action: none (will be wrapped under previously + # | # created if node) + + if self.get_local(CONTINUE_USED, False): + if self.get_local(GUARD_CREATED, False): + return node, None + + elif not self.get_local(CREATE_GUARD_NEXT, False): + self.set_local(CREATE_GUARD_NEXT, True) + return node, None + + else: + self.set_local(GUARD_CREATED, True) + template = """ + if not var_name: + original_node + """ + cond, = templates.replace( + template, + var_name=self.get_local(CONTROL_VAR_NAME), + original_node=node) + return cond, cond.body + return node, None + + def _visit_loop_body(self, node, nodes): + self.enter_local_scope() + scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + continue_var = self.context.namer.new_symbol('continue_', scope.referenced) + 
self.set_local(CONTROL_VAR_NAME, continue_var) + + nodes = self.visit_block(nodes, after_visit=self._postprocess_statement) + + if self.get_local(CONTINUE_USED, False): + template = """ + var_name = False + """ + control_var_init = templates.replace(template, var_name=continue_var) + nodes = control_var_init + nodes + + self.exit_local_scope() + return nodes + + def _visit_non_loop_body(self, nodes): + self.enter_local_scope(inherit=(CONTROL_VAR_NAME,)) + nodes = self.visit_block(nodes, after_visit=self._postprocess_statement) + continue_used = self.get_local(CONTINUE_USED, False) + self.exit_local_scope(keep=(CONTINUE_USED,)) + return nodes, continue_used def visit_While(self, node): - self.generic_visit(node.test) - node.body = self._process_loop_block(node.body, - anno.getanno(node, - NodeAnno.BODY_SCOPE)) - for n in node.orelse: - self.generic_visit(n) + node.test = self.visit(node.test) + node.body = self._visit_loop_body(node, node.body) + # A continue in the else clause applies to the containing scope. + node.orelse, _ = self._visit_non_loop_body(node.orelse) return node def visit_For(self, node): - self.generic_visit(node.target) - self.generic_visit(node.iter) - node.body = self._process_loop_block(node.body, - anno.getanno(node, - NodeAnno.BODY_SCOPE)) - for n in node.orelse: - self.generic_visit(n) + node.target = self.generic_visit(node.target) + node.iter = self.generic_visit(node.iter) + node.body = self._visit_loop_body(node, node.body) + # A continue in the else clause applies to the containing scope. 
+ node.orelse, _ = self._visit_non_loop_body(node.orelse) return node def visit_If(self, node): - if self.continuation_uses: - self.generic_visit(node.test) - node.body = self._visit_and_reindent_if_necessary(node.body) - continue_used_in_body = self.continuation_uses[-1][0] - node.orelse = self._visit_and_reindent_if_necessary(node.orelse) - self.continuation_uses[-1][0] = ( - continue_used_in_body or self.continuation_uses[-1][0]) - else: - node = self.generic_visit(node) + node.test = self.generic_visit(node.test) + node.body, continue_used_body = self._visit_non_loop_body(node.body) + node.orelse, continue_used_orelse = self._visit_non_loop_body(node.orelse) + self.set_local(CONTINUE_USED, continue_used_body or continue_used_orelse) return node - def visit_Continue(self, node): - self.continuation_uses[-1][0] = True - return self._create_continuation_trigger() - - def visit_Break(self, node): - assert False, 'break statement should be desugared at this point' + def visit_With(self, node): + node.items = self.visit_block(node.items) + node.body, _ = self._visit_non_loop_body(node.body) + return node def transform(node, namer): diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py index 935a2786db0289c67860be2da97e3f554f12500c..d7ddbe8a04f64848d6ec21155d8d85f60e19d276 100644 --- a/tensorflow/contrib/autograph/converters/control_flow.py +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Handles control flow statements: while, if.""" +"""Handles control flow statements: while, for, if.""" from __future__ import absolute_import from __future__ import division @@ -25,6 +25,7 @@ from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis import cfg from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -47,9 +48,6 @@ class SymbolNamer(object): class ControlFlowTransformer(transformer.Base): """Transforms control flow structures like loops an conditionals.""" - def __init__(self, context): - super(ControlFlowTransformer, self).__init__(context) - def _create_cond_branch(self, body_name, aliased_orig_names, aliased_new_names, body, returns): if aliased_orig_names: @@ -98,30 +96,63 @@ class ControlFlowTransformer(transformer.Base): body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE) - - if body_scope.created - orelse_scope.created: - raise ValueError( - 'The if branch creates new symbols that the else branch does not.') - if orelse_scope.created - body_scope.created: - raise ValueError( - 'The else branch creates new symbols that the if branch does not.') - - modified = tuple(body_scope.modified | orelse_scope.modified) - all_referenced = body_scope.referenced | orelse_scope.referenced + body_defs = body_scope.created | body_scope.modified + orelse_defs = orelse_scope.created | orelse_scope.modified + live = anno.getanno(node, 'live_out') + + # We'll need to check if we're closing over variables that are defined + # elsewhere in the function + # NOTE: we can only detect syntactic closure in the scope + # of the code passed in. 
If the AutoGraph'd function itself closes + # over other variables, this analysis won't take that into account. + defined = anno.getanno(node, 'defined_in') + + # We only need to return variables that are + # - modified by one or both branches + # - live (or has a live parent) at the end of the conditional + modified = [] + for def_ in body_defs | orelse_defs: + def_with_parents = set((def_,)) | def_.support_set + if live & def_with_parents: + modified.append(def_) + + # We need to check if live created variables are balanced + # in both branches + created = live & (body_scope.created | orelse_scope.created) + + # The if statement is illegal if there are variables that are created, + # that are also live, but both branches don't create them. + if created: + if created != (body_scope.created & live): + raise ValueError( + 'The main branch does not create all live symbols that the else ' + 'branch does.') + if created != (orelse_scope.created & live): + raise ValueError( + 'The else branch does not create all live symbols that the main ' + 'branch does.') # Alias the closure variables inside the conditional functions # to avoid errors caused by the local variables created in the branch # functions. - need_alias = ( - (body_scope.modified | orelse_scope.modified) - - (body_scope.created | orelse_scope.created)) - aliased_orig_names = tuple(need_alias) - aliased_new_names = tuple( - self.context.namer.new_symbol(s.ssf(), all_referenced) - for s in aliased_orig_names) - alias_map = dict(zip(aliased_orig_names, aliased_new_names)) - node_body = ast_util.rename_symbols(node.body, alias_map) - node_orelse = ast_util.rename_symbols(node.orelse, alias_map) + # We will alias variables independently for body and orelse scope, + # because different branches might write different variables. 
+ aliased_body_orig_names = tuple(body_scope.modified - body_scope.created) + aliased_orelse_orig_names = tuple(orelse_scope.modified - + orelse_scope.created) + aliased_body_new_names = tuple( + self.context.namer.new_symbol(s.ssf(), body_scope.referenced) + for s in aliased_body_orig_names) + aliased_orelse_new_names = tuple( + self.context.namer.new_symbol(s.ssf(), orelse_scope.referenced) + for s in aliased_orelse_orig_names) + + alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names)) + alias_orelse_map = dict( + zip(aliased_orelse_orig_names, aliased_orelse_new_names)) + + node_body = ast_util.rename_symbols(node.body, alias_body_map) + node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map) if not modified: # When the cond would return no value, we leave the cond called without @@ -134,26 +165,47 @@ class ControlFlowTransformer(transformer.Base): else: results = gast.Tuple([s.ast() for s in modified], None) - body_name = self.context.namer.new_symbol('if_true', all_referenced) - orelse_name = self.context.namer.new_symbol('if_false', all_referenced) + body_name = self.context.namer.new_symbol('if_true', body_scope.referenced) + orelse_name = self.context.namer.new_symbol('if_false', + orelse_scope.referenced) if modified: - body_returns = tuple( - alias_map[s] if s in aliased_orig_names else s for s in modified) + + def build_returns(aliased_names, alias_map, scope): + """Builds list of return variables for a branch of a conditional.""" + returns = [] + for s in modified: + if s in aliased_names: + returns.append(alias_map[s]) + else: + if s not in scope.created | defined: + raise ValueError( + 'Attempting to return variable "%s" from the true branch of ' + 'a conditional, but it was not closed over, or created in ' + 'this branch.' 
% str(s)) + else: + returns.append(s) + return tuple(returns) + + body_returns = build_returns(aliased_body_orig_names, alias_body_map, + body_scope) + orelse_returns = build_returns(aliased_orelse_orig_names, + alias_orelse_map, orelse_scope) + else: - body_returns = templates.replace('tf.ones(())')[0].value + body_returns = orelse_returns = templates.replace('tf.ones(())')[0].value body_def = self._create_cond_branch( body_name, - aliased_orig_names=tuple(aliased_orig_names), - aliased_new_names=tuple(aliased_new_names), + aliased_orig_names=tuple(aliased_body_orig_names), + aliased_new_names=tuple(aliased_body_new_names), body=node_body, returns=body_returns) orelse_def = self._create_cond_branch( orelse_name, - aliased_orig_names=tuple(aliased_orig_names), - aliased_new_names=tuple(aliased_new_names), + aliased_orig_names=tuple(aliased_orelse_orig_names), + aliased_new_names=tuple(aliased_orelse_new_names), body=node_orelse, - returns=body_returns) + returns=orelse_returns) cond_expr = self._create_cond_expr(results, node.test, body_name, orelse_name) @@ -284,6 +336,7 @@ class ControlFlowTransformer(transformer.Base): def transform(node, context): - t = ControlFlowTransformer(context) - node = t.visit(node) + cfg.run_analyses(node, cfg.Liveness(context)) + cfg.run_analyses(node, cfg.Defined(context)) + node = ControlFlowTransformer(context).visit(node) return node diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py index c5610b16b4e5de374f404307d3583660707d5e0b..9d23d9b5b7e8e8480e04fccc1c8c81799abf382b 100644 --- a/tensorflow/contrib/autograph/converters/control_flow_test.py +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -22,6 +22,7 @@ from tensorflow.contrib.autograph.converters import control_flow from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.framework import constant_op from tensorflow.python.framework 
import dtypes +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import test @@ -41,7 +42,7 @@ class ControlFlowTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, self.ctx) - with self.compiled(node, control_flow_ops.while_loop) as result: + with self.compiled(node) as result: with self.test_session() as sess: self.assertEqual((10, 5, 5), sess.run(result.test_fn(constant_op.constant(5)))) @@ -56,7 +57,7 @@ class ControlFlowTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, self.ctx) - with self.compiled(node, control_flow_ops.while_loop) as result: + with self.compiled(node) as result: with self.test_session() as sess: self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5)))) @@ -74,7 +75,7 @@ class ControlFlowTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, self.ctx) - with self.compiled(node, control_flow_ops.cond) as result: + with self.compiled(node) as result: with self.test_session() as sess: self.assertEqual((-1, 0), sess.run(result.test_fn(constant_op.constant(1)))) @@ -91,10 +92,95 @@ class ControlFlowTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, self.ctx) - with self.compiled(node, control_flow_ops.cond) as result: + with self.compiled(node) as result: with self.test_session() as sess: self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1)))) + def test_imbalanced_aliasing(self): + + def test_fn(n): + if n > 0: + n = 3 + return n + + node = self.parse_and_analyze(test_fn, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.cond) as result: + with self.test_session() as sess: + self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(2)))) + 
self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3)))) + + def test_ignore_unread_variable(self): + + def test_fn(n): + b = 3 # pylint: disable=unused-variable + if n > 0: + b = 4 + return n + + node = self.parse_and_analyze(test_fn, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: + with self.test_session() as sess: + self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(3)))) + self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3)))) + + def test_handle_temp_variable(self): + + def test_fn_using_temp(x, y, w): + if x < y: + z = x + y + else: + w = 2 + tmp = w + z = x - tmp + return z, w + + node = self.parse_and_analyze(test_fn_using_temp, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: + with self.test_session() as sess: + z, w = sess.run( + result.test_fn_using_temp( + constant_op.constant(-3), constant_op.constant(3), + constant_op.constant(3))) + self.assertEqual(0, z) + self.assertEqual(3, w) + z, w = sess.run( + result.test_fn_using_temp( + constant_op.constant(3), constant_op.constant(-3), + constant_op.constant(3))) + self.assertEqual(1, z) + self.assertEqual(2, w) + + def test_fn_ignoring_temp(x, y, w): + if x < y: + z = x + y + else: + w = 2 + tmp = w + z = x - tmp + return z + + node = self.parse_and_analyze(test_fn_ignoring_temp, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: + with self.test_session() as sess: + z = sess.run( + result.test_fn_ignoring_temp( + constant_op.constant(-3), constant_op.constant(3), + constant_op.constant(3))) + self.assertEqual(0, z) + z = sess.run( + result.test_fn_ignoring_temp( + constant_op.constant(3), constant_op.constant(-3), + constant_op.constant(3))) + self.assertEqual(1, z) + def test_simple_for(self): def test_fn(l): diff 
--git a/tensorflow/contrib/autograph/converters/lists.py b/tensorflow/contrib/autograph/converters/lists.py index b49521b2c328f418828a5e92890aa1b169384b70..c15dfff9e8ebd8b96fd4aff82459a6fd7d0ac8ab 100644 --- a/tensorflow/contrib/autograph/converters/lists.py +++ b/tensorflow/contrib/autograph/converters/lists.py @@ -33,82 +33,193 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates from tensorflow.contrib.autograph.pyct import transformer -from tensorflow.python.framework import dtypes +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno + + +# Tags for local state. +POP_USES = 'pop_uses' class ListTransformer(transformer.Base): """Converts lists and related operations to their TF counterpart.""" - def _empty_list(self, node): - if not anno.hasanno(node, 'element_type'): - raise NotImplementedError( - 'type inference for empty lists is not yet supported; ' - 'use set_element_type(, ) to continue') - dtype = anno.getanno(node, 'element_type') - if not isinstance(dtype, dtypes.DType): - # TODO(mdan): Allow non-TF dtypes? - # That would be consistent with the dynamic dispatch pattern, but - # we must make sure that doesn't become confusing. - raise NotImplementedError('element type "%s" not yet supported' % dtype) - - dtype_name = dtype.name - # TODO(mdan): Does it ever make sense not to use tensor lists? 
+ def visit_List(self, node): + node = self.generic_visit(node) template = """ - tf.TensorArray(tf.dtype_name, size=0, dynamic_size=True) + ag__.new_list(elements) """ - return templates.replace_as_expression(template, dtype_name=dtype_name) + return templates.replace_as_expression(template, elements=node) - def _pre_populated_list(self, node): - raise NotImplementedError('pre-populated lists') + def _replace_append_call(self, node): + assert len(node.args) == 1 + assert isinstance(node.func, gast.Attribute) + template = """ + target = ag__.list_append(target, element) + """ + return templates.replace( + template, + target=node.func.value, + element=node.args[0]) + + def _replace_pop_call(self, node): + # Expressions that use pop() are converted to a statement + expression. + # + # For example: + # + # print(target.pop()) + # + # ... is converted to: + # + # target, target_pop = ag__.list_pop(target) + # print(target_pop) + # + # Here, we just generate the variable name and swap it in, + # and _generate_pop_operation will handle the rest. + # + # Multiple uses of pop() are allowed: + # + # print(target.pop(), target.pop()) + # print(target.pop().pop()) + # + assert isinstance(node.func, gast.Attribute) + scope = anno.getanno(node, NodeAnno.ARGS_SCOPE) + target_node = node.func.value + + # Attempt to use a related name if we can get one. Otherwise use something + # generic.
+ if anno.hasanno(target_node, anno.Basic.QN): + target_name = anno.getanno(target_node, anno.Basic.QN).ssf() + else: + target_name = 'list' + pop_var_name = self.context.namer.new_symbol(target_name, scope.referenced) + + pop_uses = self.get_local(POP_USES, []) + pop_uses.append((node, pop_var_name)) + self.set_local(POP_USES, pop_uses) + + return templates.replace_as_expression('var_name', var_name=pop_var_name) + + def _replace_stack_call(self, node): + assert len(node.args) == 1 + dtype = anno.getanno( + node.args[0], + 'element_type', + default=templates.replace_as_expression('None')) + template = """ + ag__.list_stack( + target, + opts=ag__.ListStackOpts( + element_dtype=dtype, + original_call=orig_call)) + """ + return templates.replace_as_expression( + template, + dtype=dtype, + target=node.args[0], + orig_call=node.func) - def visit_Expr(self, node): + def visit_Call(self, node): node = self.generic_visit(node) - if isinstance(node.value, gast.Call): - call_node = node.value - - if not anno.hasanno(call_node.func, anno.Basic.QN): - return node - qn = anno.getanno(call_node.func, anno.Basic.QN) - - if qn.qn[-1] == 'append' and (len(call_node.args) == 1): - template = """ - target = ag__.utils.dynamic_list_append(target, element) - """ - node = templates.replace( - template, - target=qn.parent.ast(), - element=call_node.args[0]) + + # TODO(mdan): This is insufficient if target is a function argument. + # In the case of function arguments, we need to add the list to the + # function's return value, because it is being modified. + # TODO(mdan): Checking just the name is brittle, can it be improved? 
+ if isinstance(node.func, gast.Attribute): + func_name = node.func.attr + if func_name == 'append' and (len(node.args) == 1): + node = self._replace_append_call(node) + elif func_name == 'pop' and (len(node.args) <= 1): + node = self._replace_pop_call(node) + elif func_name == 'stack' and (len(node.args) == 1): + node = self._replace_stack_call(node) + return node - def _replace_list_constructors(self, targets, values): - for target in targets: - if (isinstance(target, (gast.Tuple, gast.List)) and - isinstance(values, (gast.Tuple, gast.List))): - n_targets = len(target.elts) - for i in range(n_targets): - target_el, value_el = target.elts[i], values.elts[i] - values.elts[i] = self._replace_list_constructors( - (target_el,), value_el) - return values - if isinstance(values, gast.List): - if values.elts: - return self._pre_populated_list(values) - else: - return self._empty_list(values) - return values - - def visit_Assign(self, node): - node = self.generic_visit(node) + def _generate_pop_operation(self, original_call_node, pop_var_name): + assert isinstance(original_call_node.func, gast.Attribute) + + if original_call_node.args: + pop_element = original_call_node.args[0] + else: + pop_element = parser.parse_expression('None') + # The call will be something like "target.pop()", and the dtype is hooked to + # target, hence the func.value. 
+ dtype = anno.getanno( + original_call_node.func.value, + 'element_type', + default=templates.replace_as_expression('None')) + shape = anno.getanno( + original_call_node.func.value, + 'element_shape', + default=templates.replace_as_expression('None')) + + template = """ + target, pop_var_name = ag__.list_pop( + target, element, + opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape)) + """ + return templates.replace( + template, + target=original_call_node.func.value, + pop_var_name=pop_var_name, + element=pop_element, + dtype=dtype, + shape=shape) + + def _postprocess_statement(self, node): + """Inserts any separate pop() calls that node may use.""" + pop_uses = self.get_local(POP_USES, None) + if pop_uses: + replacements = [] + for original_call_node, pop_var_name in pop_uses: + replacements.extend( + self._generate_pop_operation(original_call_node, pop_var_name)) + replacements.append(node) + node = replacements + self.exit_local_scope() + return node, None + + # TODO(mdan): Should we have a generic visit_block instead? + # Right now it feels that a visit_block would add too much magic that's + # hard to follow. 
+ + def _visit_and_process_block(self, block): + return self.visit_block( + block, + before_visit=self.enter_local_scope, + after_visit=self._postprocess_statement) + + def visit_FunctionDef(self, node): + node.args = self.generic_visit(node.args) + node.decorator_list = self.visit_block(node.decorator_list) + node.body = self._visit_and_process_block(node.body) + return node + + def visit_For(self, node): + node.target = self.visit(node.target) + node.body = self._visit_and_process_block(node.body) + node.orelse = self._visit_and_process_block(node.orelse) + return node + + def visit_While(self, node): + node.test = self.visit(node.test) + node.body = self._visit_and_process_block(node.body) + node.orelse = self._visit_and_process_block(node.orelse) + return node + + def visit_If(self, node): + node.test = self.visit(node.test) + node.body = self._visit_and_process_block(node.body) + node.orelse = self._visit_and_process_block(node.orelse) + return node - # Only convert lists when they are assigned to a variable, e.g.: - # l = [] - # TODO(mdan): A similar pattern exists in type_info.py - # We should add a generic "unpack_assignment" function to the base - # transformer, that has the same effect as applying some logic to the SSA - # form. 
- node.value = self._replace_list_constructors(node.targets, node.value) + def visit_With(self, node): + node.items = self.visit_block(node.items) + node.body = self._visit_and_process_block(node.body) return node diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/contrib/autograph/converters/lists_test.py index 74c6dc64f197f75eb3e66c01fb078467e8e8ea89..9f18ab9f44dd8c3f341a02b950f75317c676eff8 100644 --- a/tensorflow/contrib/autograph/converters/lists_test.py +++ b/tensorflow/contrib/autograph/converters/lists_test.py @@ -22,74 +22,126 @@ from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import lists from tensorflow.python.framework import dtypes -from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import list_ops from tensorflow.python.platform import test class ListTest(converter_test_base.TestCase): - def test_empty_annotated_list(self): + def test_empty_list(self): def test_fn(): - l = [] - utils.set_element_type(l, dtypes.int32) - l.append(1) - return l + return [] - node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils}) + node = self.parse_and_analyze(test_fn, {}) node = lists.transform(node, self.ctx) - with self.compiled(node, tensor_array_ops.TensorArray, - dtypes.int32) as result: - # TODO(mdan): Attach these additional modules automatically. - result.utils = utils - result.dtypes = dtypes + with self.compiled(node) as result: + tl = result.test_fn() + # Empty tensor lists cannot be evaluated or stacked. 
+ self.assertTrue(isinstance(tl, ops.Tensor)) + self.assertEqual(tl.dtype, dtypes.variant) + + def test_initialized_list(self): + + def test_fn(): + return [1, 2, 3] + + node = self.parse_and_analyze(test_fn, {}) + node = lists.transform(node, self.ctx) + + with self.compiled(node) as result: with self.test_session() as sess: - self.assertAllEqual([1], sess.run(result.test_fn().stack())) + tl = result.test_fn() + r = list_ops.tensor_list_stack(tl, dtypes.int32) + self.assertAllEqual(sess.run(r), [1, 2, 3]) - def test_empty_annotated_lists_unpacked(self): + def test_list_append(self): def test_fn(): - l, m = [], [] - utils.set_element_type(l, dtypes.int32) - utils.set_element_type(m, dtypes.int32) - l.append(1) - m.append(2) - return l, m + l = [1] + l.append(2) + l.append(3) + return l - node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils}) + node = self.parse_and_analyze(test_fn, {}) node = lists.transform(node, self.ctx) - with self.compiled(node, tensor_array_ops.TensorArray, - dtypes.int32) as result: + with self.compiled(node) as result: + with self.test_session() as sess: + tl = result.test_fn() + r = list_ops.tensor_list_stack(tl, dtypes.int32) + self.assertAllEqual(sess.run(r), [1, 2, 3]) + + def test_list_pop(self): + + def test_fn(): + l = [1, 2, 3] + utils.set_element_type(l, dtypes.int32, ()) + s = l.pop() + return s, l + + node = self.parse_and_analyze( + test_fn, + { + 'utils': utils, + 'dtypes': dtypes + }, + include_type_analysis=True, + ) + node = lists.transform(node, self.ctx) + + with self.compiled(node) as result: result.utils = utils result.dtypes = dtypes with self.test_session() as sess: - res_l, res_m = result.test_fn() - self.assertEqual([1], sess.run(res_l.stack())) - self.assertEqual([2], sess.run(res_m.stack())) + ts, tl = result.test_fn() + r = list_ops.tensor_list_stack(tl, dtypes.int32) + self.assertAllEqual(sess.run(r), [1, 2]) + self.assertAllEqual(sess.run(ts), 3) + + def test_double_list_pop(self): - def 
test_empty_annotated_lists_list_unpacked(self): + def test_fn(l): + s = l.pop().pop() + return s + + node = self.parse_and_analyze(test_fn, {}) + node = lists.transform(node, self.ctx) + + with self.compiled(node) as result: + test_input = [1, 2, [1, 2, 3]] + # TODO(mdan): Pass a list of lists of tensor when we fully support that. + # For now, we just pass a regular Python list of lists just to verify that + # the two pop calls are sequenced properly. + self.assertAllEqual(result.test_fn(test_input), 3) + + def test_list_stack(self): + + tf = None # Will be replaced with a mock. def test_fn(): - [l, m] = [], [] + l = [1, 2, 3] utils.set_element_type(l, dtypes.int32) - utils.set_element_type(m, dtypes.int32) - l.append(1) - m.append(2) - return l, m - - node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils}) + return tf.stack(l) + + node = self.parse_and_analyze( + test_fn, + { + 'utils': utils, + 'dtypes': dtypes + }, + include_type_analysis=True, + ) node = lists.transform(node, self.ctx) - with self.compiled(node, tensor_array_ops.TensorArray, - dtypes.int32) as result: + with self.compiled(node, array_ops.stack, dtypes.int32) as result: result.utils = utils result.dtypes = dtypes with self.test_session() as sess: - res_l, res_m = result.test_fn() - self.assertEqual([1], sess.run(res_l.stack())) - self.assertEqual([2], sess.run(res_m.stack())) + self.assertAllEqual(sess.run(result.test_fn()), [1, 2, 3]) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/slices.py b/tensorflow/contrib/autograph/converters/slices.py new file mode 100644 index 0000000000000000000000000000000000000000..85aeda9c4164eb70329bd50f789eea5441c8fc87 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/slices.py @@ -0,0 +1,83 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converter for slice operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer + + +class SliceTransformer(transformer.Base): + """Converts slicing operations to their TF counterpart. + + Currently, relying on the default slice operator that Tensor uses is + insufficient, because TensorArray and tensor lists use dedicated index read + and write functions. + """ + + def _process_single_assignment(self, target, value): + if not isinstance(target, gast.Subscript): + return None + + template = """ + target = ag__.set_item(target, key, item) + """ + return templates.replace( + template, target=target.value, key=target.slice, item=value) + + def visit_Assign(self, node): + node = self.generic_visit(node) + # TODO(mdan): Support unpackings and multiple assignments. + if len(node.targets) != 1: + raise NotImplementedError('multiple assignment') + replacement = self._process_single_assignment(node.targets[0], node.value) + if replacement is not None: + return replacement + return node + + def visit_Subscript(self, node): + node = self.generic_visit(node) + if not isinstance(node.slice, gast.Index): + # TODO(mdan): It might make more sense to wave them through. 
+ raise NotImplementedError('non-index slice') + + if not isinstance(node.ctx, gast.Load): + # Index writes are handled at a higher level, one at which the rvalue is + # also available. + return node + + dtype = anno.getanno( + node.value, + 'element_type', + default=templates.replace_as_expression('None')) + + template = """ + ag__.get_item( + target, + key, + opts=ag__.GetItemOpts(element_dtype=dtype)) + """ + return templates.replace_as_expression( + template, target=node.value, key=node.slice, dtype=dtype) + + +def transform(node, context): + return SliceTransformer(context).visit(node) diff --git a/tensorflow/contrib/autograph/converters/slices_test.py b/tensorflow/contrib/autograph/converters/slices_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6c2d7e1ea1a6c46fcc3a2c6972a24507646ef858 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/slices_test.py @@ -0,0 +1,59 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for slices module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import slices +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +class SliceTest(converter_test_base.TestCase): + + def test_index_access(self): + + def test_fn(l): + utils.set_element_type(l, dtypes.int32) + return l[1] + + node = self.parse_and_analyze( + test_fn, + { + 'utils': utils, + 'dtypes': dtypes + }, + include_type_analysis=True, + ) + node = slices.transform(node, self.ctx) + + with self.compiled(node, dtypes.int32) as result: + result.utils = utils + result.dtypes = dtypes + with self.test_session() as sess: + tl = list_ops.tensor_list_from_tensor( + [1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32)) + y = result.test_fn(tl) + self.assertEqual(2, sess.run(y)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/impl/BUILD b/tensorflow/contrib/autograph/impl/BUILD index 54424e26472b8466b8fe68ea848b5463c10224c9..02f16ae1875d6bd1fb87d19f8bfc5cae900391dd 100644 --- a/tensorflow/contrib/autograph/impl/BUILD +++ b/tensorflow/contrib/autograph/impl/BUILD @@ -20,7 +20,9 @@ py_library( "api.py", "config.py", "conversion.py", + "directives.py", "naming.py", + "special_functions.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], @@ -69,3 +71,13 @@ py_test( "//tensorflow/python:client_testlib", ], ) + +py_test( + name = "special_functions_test", + srcs = ["special_functions_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":impl", + 
"//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py index 55a30dc127957b2a9caa053db843380c94bacfbf..7802bbbe27ec5fed891440af2f589801918b3bdd 100644 --- a/tensorflow/contrib/autograph/impl/conversion.py +++ b/tensorflow/contrib/autograph/impl/conversion.py @@ -38,6 +38,7 @@ from tensorflow.contrib.autograph.converters import logical_expressions from tensorflow.contrib.autograph.converters import name_scopes from tensorflow.contrib.autograph.converters import side_effect_guards from tensorflow.contrib.autograph.converters import single_return +from tensorflow.contrib.autograph.converters import slices from tensorflow.contrib.autograph.impl import config from tensorflow.contrib.autograph.impl import naming from tensorflow.contrib.autograph.pyct import ast_util @@ -371,6 +372,8 @@ def node_to_graph(node, ctx, nocompile_decorators): # TODO(mdan): Clean this up. # Some intermediate analyses are not required, and some comments got orphaned. + # TODO(mdan): We may assume all converters require analysis to be re-done. + # Past this point, line numbers are no longer accurate so we ignore the # source. # TODO(mdan): Is it feasible to reconstruct intermediate source code? 
@@ -393,6 +396,8 @@ def node_to_graph(node, ctx, nocompile_decorators): node = _static_analysis_pass(node, ctx) node = lists.transform(node, ctx) + node = _static_analysis_pass(node, ctx) + node = slices.transform(node, ctx) node = builtin_functions.transform(node, ctx) node = _static_analysis_pass(node, ctx) diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py index 5edd8e74a8899a25fb51e2a4e133f3cb7933fa26..bc61498b5422f5e130bbfeef935d0a796b4f5922 100644 --- a/tensorflow/contrib/autograph/impl/conversion_test.py +++ b/tensorflow/contrib/autograph/impl/conversion_test.py @@ -24,7 +24,7 @@ from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph.impl import api from tensorflow.contrib.autograph.impl import conversion from tensorflow.python.framework import constant_op -from tensorflow.python.keras._impl.keras.engine import training +from tensorflow.python.keras.engine import training from tensorflow.python.platform import test diff --git a/tensorflow/contrib/autograph/impl/directives.py b/tensorflow/contrib/autograph/impl/directives.py new file mode 100644 index 0000000000000000000000000000000000000000..aabe5d99394a0cb921196d1c6a6b2a9496ea7545 --- /dev/null +++ b/tensorflow/contrib/autograph/impl/directives.py @@ -0,0 +1,68 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Directives are special no-op functions that serve as compilation markers. + +They provide static information like type hints, compilation and TensorFlow +overrides. + +These serve as annotations in the compiled code, allowing the user some control +over the compilation process. They have no functional role at runtime. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +UNSPECIFIED = object() + + +def set_element_type(entity, dtype, shape=UNSPECIFIED): + """Indicates that the entity is expected to hold items of specified type/shape. + + The staged TensorFlow ops will reflect and assert this data type. Ignored + otherwise. + + Args: + entity: The entity to annotate. + dtype: TensorFlow dtype value to assert for entity. + shape: Optional shape to assert for entity. + """ + del entity + del dtype + del shape + + +def set_loop_options( + parallel_iterations=UNSPECIFIED, + back_prop=UNSPECIFIED, + swap_memory=UNSPECIFIED, + maximum_iterations=UNSPECIFIED): + """Specifies additional arguments to be passed to the enclosing while_loop. + + The parameters apply to and only to the immediately enclosing loop. It only + has effect if the loop is staged as a TF while_loop; otherwise the parameters + have no effect. + + Args: + parallel_iterations: See tf.while_loop. + back_prop: See tf.while_loop. + swap_memory: See tf.while_loop. + maximum_iterations: See tf.while_loop. + """ + del parallel_iterations + del back_prop + del swap_memory + del maximum_iterations diff --git a/tensorflow/contrib/autograph/impl/special_functions.py b/tensorflow/contrib/autograph/impl/special_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..b7a8177c44c88217560fb7f72c77d3ac1aa0c9ec --- /dev/null +++ b/tensorflow/contrib/autograph/impl/special_functions.py @@ -0,0 +1,48 @@ +# Copyright 2017 The TensorFlow Authors.
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Special functions that only make sense for AutoGraph. + +These functions are meant to ensure feature parity between Python and AutoGraph, +so that the exact same code works in both modes. In general, AutoGraph will +replace these calls. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.operators import data_structures + + +def stack(list_or_tensor, element_dtype=None): + """Stacks the input, if it admits the notion of stacking. No-op otherwise. + + For example, a list of tensors can be stacked into a larger tensor. This + function is similar to tf.stack, but it accepts non-lists and lists of + non-tensors as arguments. In the latter case, the function does nothing. + + Args: + list_or_tensor: Any entity. + element_dtype: Optional dtype for the elements in the list. Required if the + input is stackable, and the list is untyped. + + Returns: + If the input is stackable, a new object representing the stacked inputs. + Otherwise it returns list_or_tensor unchanged. 
+ """ + return data_structures.list_stack( + list_or_tensor, + data_structures.ListStackOpts( + element_dtype=element_dtype, original_call=lambda x: x)) diff --git a/tensorflow/contrib/autograph/impl/special_functions_test.py b/tensorflow/contrib/autograph/impl/special_functions_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9b52d2a59b5a3e3c92f11343197379c773ecc828 --- /dev/null +++ b/tensorflow/contrib/autograph/impl/special_functions_test.py @@ -0,0 +1,50 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for special_functions module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.impl import special_functions +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +class SpecialFunctionsTest(test.TestCase): + + def test_basic(self): + self.assertEqual(special_functions.stack(1), 1) + self.assertListEqual(special_functions.stack([1, 2, 3]), [1, 2, 3]) + # TODO(mdan): This should probably forward to tf.stack. 
+ self.assertTrue( + isinstance( + special_functions.stack( + [constant_op.constant(1), + constant_op.constant(2)]), list)) + + t = constant_op.constant([1.0, 2.0]) + l = list_ops.tensor_list_from_tensor( + t, element_shape=constant_op.constant([], dtype=dtypes.int32)) + self.assertTrue( + tensor_util.is_tensor( + special_functions.stack(l, element_dtype=dtypes.float32))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD index 18bfec5d9c69912f90414c51ac63ba540cf4d5fc..0c6ab65505ee03e19588adae73d3134399a34b65 100644 --- a/tensorflow/contrib/autograph/operators/BUILD +++ b/tensorflow/contrib/autograph/operators/BUILD @@ -22,7 +22,7 @@ py_library( "__init__.py", "control_flow.py", "data_structures.py", - "dispatch_context.py", + "slices.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], @@ -52,3 +52,13 @@ py_test( "//tensorflow/python:client_testlib", ], ) + +py_test( + name = "slices_test", + srcs = ["slices_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":operators", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py index 38b761d97d54bdaee4da91269964469b482895ae..c900fd6af2ea5dfb419f731ee8d8822d68424b27 100644 --- a/tensorflow/contrib/autograph/operators/__init__.py +++ b/tensorflow/contrib/autograph/operators/__init__.py @@ -28,6 +28,10 @@ closures for the body. # - the names used in the Python docs, if the operator is a function (e.g. # list_ and x for append, see # https://docs.python.org/3.7/tutorial/datastructures.html) +# +# All operators may accept a final argument named "opts", of a type that +# subclasses namedtuple and contains any arguments that are only required +# for some specializations of the operator. 
from __future__ import absolute_import from __future__ import division @@ -35,3 +39,12 @@ from __future__ import print_function from tensorflow.contrib.autograph.operators.control_flow import for_stmt from tensorflow.contrib.autograph.operators.control_flow import while_stmt +from tensorflow.contrib.autograph.operators.data_structures import list_append +from tensorflow.contrib.autograph.operators.data_structures import list_pop +from tensorflow.contrib.autograph.operators.data_structures import list_stack +from tensorflow.contrib.autograph.operators.data_structures import ListPopOpts +from tensorflow.contrib.autograph.operators.data_structures import ListStackOpts +from tensorflow.contrib.autograph.operators.data_structures import new_list +from tensorflow.contrib.autograph.operators.slices import get_item +from tensorflow.contrib.autograph.operators.slices import GetItemOpts +from tensorflow.contrib.autograph.operators.slices import set_item diff --git a/tensorflow/contrib/autograph/operators/data_structures.py b/tensorflow/contrib/autograph/operators/data_structures.py index c862306baa9e8114a71a26323ddcbd35c8592c55..06d8727b0fcc30b532b3f11281cd1a83c51ac8bc 100644 --- a/tensorflow/contrib/autograph/operators/data_structures.py +++ b/tensorflow/contrib/autograph/operators/data_structures.py @@ -18,39 +18,250 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import variables + + +# TODO(mdan): Once control flow supports objects, repackage as a class. 
+ + +def new_list(iterable=None): + """The list constructor. + + Args: + iterable: Optional elements to fill the list with. + + Returns: + A list-like object. The exact return value depends on the initial elements. + """ + if iterable: + elements = tuple(iterable) + else: + elements = () + + # TODO(mdan): Extend these criteria. + if any(isinstance(el, variables.Variable) for el in elements): + return _py_list_new(elements) + return _tf_tensor_list_new(elements) -# TODO(mdan): Add support for TensorList once functional. -# TODO(mdan): Add primitives for empty list, list with elements. +def _tf_tensor_list_new(elements): + """Overload of new_list that stages a Tensor list creation.""" + elements = tuple(ops.convert_to_tensor(el) for el in elements) + all_dtypes = set(el.dtype for el in elements) + if len(all_dtypes) == 1: + element_dtype = tuple(all_dtypes)[0] + else: + # Heterogeneous lists are ok. + element_dtype = dtypes.variant + + # TODO(mdan): This may fail for elements of variable shapes. + all_shapes = set(tuple(el.shape.as_list()) for el in elements) + if len(all_shapes) == 1: + element_shape = array_ops.shape(elements[0]) + else: + # Heterogeneous lists are ok. + element_shape = constant_op.constant(-1) # unknown shape, by convention + + l = list_ops.empty_tensor_list( + element_shape=element_shape, element_dtype=element_dtype) + for el in elements: + l = list_ops.tensor_list_push_back(l, el) + return l -def append(target, element): + +def _py_list_new(elements): + """Overload of new_list that creates a Python list.""" + return list(elements) + + +def list_append(list_, x): + """The list append function. - Note: it is unspecified where target will be mutated or not. If target is - a TensorFlow entity, it will not be typically mutated. If target is a plain - list, it will be. In general, if the target is mutated then the return value + Note: it is unspecified whether list_ will be mutated or not.
If list_ is + a TensorFlow entity, it will not be typically mutated. If list_ is a plain + list, it will be. In general, if the list is mutated then the return value should point to the original entity. Args: - target: An entity that supports append semantics. - element: The element to append. + list_: An entity that supports append semantics. + x: The element to append. Returns: - Same as target, after the append was performed. + Same as list_, after the append was performed. + + Raises: + ValueError: if list_ is not of a known list-like type. """ - if isinstance(target, tensor_array_ops.TensorArray): - return _tf_tensorarray_append(target, element) + if isinstance(list_, tensor_array_ops.TensorArray): + return _tf_tensorarray_append(list_, x) + elif tensor_util.is_tensor(list_): + if list_.dtype == dtypes.variant: + return _tf_tensor_list_append(list_, x) + else: + raise ValueError( + 'tensor lists are expected to be Tensors with dtype=tf.variant,' + ' instead found %s' % list_) else: - return _py_append(target, element) + return _py_list_append(list_, x) + + +def _tf_tensor_list_append(list_, x): + """Overload of list_append that stages a Tensor list write.""" + def empty_list_of_elements_like_x(): + tensor_x = ops.convert_to_tensor(x) + return list_ops.empty_tensor_list( + element_shape=array_ops.shape(tensor_x), + element_dtype=tensor_x.dtype) + + list_ = control_flow_ops.cond( + list_ops.tensor_list_length(list_) > 0, + lambda: list_, + empty_list_of_elements_like_x, + ) + return list_ops.tensor_list_push_back(list_, x) + + +def _tf_tensorarray_append(list_, x): + """Overload of list_append that stages a TensorArray write.""" + return list_.write(list_.size(), x) + + +def _py_list_append(list_, x): + """Overload of list_append that executes a Python list append.""" + # Revert to the original call. 
+ list_.append(x) + return list_ + + +class ListPopOpts( + collections.namedtuple('ListPopOpts', ('element_dtype', 'element_shape'))): + pass + + +def list_pop(list_, i, opts): + """The list pop function. + + Note: it is unspecified whether list_ will be mutated or not. If list_ is + a TensorFlow entity, it will not be typically mutated. If list_ is a plain + list, it will be. In general, if the list is mutated then the return value + should point to the original entity. + + Args: + list_: An entity that supports pop semantics. + i: Optional index to pop from. May be None. + opts: A ListPopOpts. + + Returns: + Tuple (out_list_, x): + out_list_: same as list_, after the removal was performed. + x: the removed element value. + + Raises: + ValueError: if list_ is not of a known list-like type or the operation is + not supported for that type. + """ + assert isinstance(opts, ListPopOpts) + + if isinstance(list_, tensor_array_ops.TensorArray): + raise ValueError('TensorArray does not support item removal') + elif tensor_util.is_tensor(list_): + if list_.dtype == dtypes.variant: + return _tf_tensor_list_pop(list_, i, opts) + else: + raise ValueError( + 'tensor lists are expected to be Tensors with dtype=tf.variant,' + ' instead found %s' % list_) + else: + return _py_list_pop(list_, i) + + +def _tf_tensor_list_pop(list_, i, opts): + """Overload of list_pop that stages a Tensor list pop.""" + if i is not None: + raise NotImplementedError('tensor lists only support removing from the end') + + if opts.element_dtype is None: + raise ValueError('cannot pop from a list without knowing its element ' + 'type; use set_element_type to annotate it') + if opts.element_shape is None: + raise ValueError('cannot pop from a list without knowing its element ' + 'shape; use set_element_type to annotate it') + list_out, x = list_ops.tensor_list_pop_back( + list_, element_dtype=opts.element_dtype) + x.set_shape(opts.element_shape) + return list_out, x + + +def _py_list_pop(list_, i): +
"""Overload of list_pop that executes a Python list pop.""" + if i is None: + x = list_.pop() + else: + x = list_.pop(i) + return list_, x + + +# TODO(mdan): Look into reducing duplication between all these containers. +class ListStackOpts( + collections.namedtuple('ListStackOpts', + ('element_dtype', 'original_call'))): + pass + + +def list_stack(list_, opts): + """The list stack function. + + This does not have a direct correspondent in Python. The closest idiom to + this is tf.stack or np.stack. It's different from those in the sense that it + accepts a Tensor list, rather than a list of tensors. It can also accept + TensorArray. When the target is anything else, the dispatcher will rely on + opts.original_call for fallback. + + Args: + list_: An entity that supports append semantics. + opts: A ListStackOpts object. + + Returns: + The output of the stack operation, typically a Tensor. + """ + assert isinstance(opts, ListStackOpts) + + if isinstance(list_, tensor_array_ops.TensorArray): + return _tf_tensorarray_stack(list_) + elif tensor_util.is_tensor(list_): + if list_.dtype == dtypes.variant: + return _tf_tensor_list_stack(list_, opts) + else: + # No-op for primitive Tensor arguments.
+ return list_ + else: + return _py_list_stack(list_, opts) + + +def _tf_tensorarray_stack(list_): + """Overload of list_stack that stages a TensorArray stack.""" + return list_.stack() -def _tf_tensorarray_append(target, element): - """Overload of append that stages a TensorArray write at the last position.""" - return target.write(target.size(), element) +def _tf_tensor_list_stack(list_, opts): + """Overload of list_stack that stages a Tensor list stack.""" + if opts.element_dtype is None: + raise ValueError('cannot stack a list without knowing its element type;' + ' use set_element_type to annotate it') + return list_ops.tensor_list_stack(list_, element_dtype=opts.element_dtype) -def _py_append(target, element): - """Overload of append that executes a Python list append.""" - target.append(element) - return target +def _py_list_stack(list_, opts): + """Overload of list_stack that falls back to the original call.""" + # Revert to the original call. + return opts.original_call(list_) diff --git a/tensorflow/contrib/autograph/operators/data_structures_test.py b/tensorflow/contrib/autograph/operators/data_structures_test.py index 577d28c34da39f1216669513c29a00ac07bec126..8bbb52d6c10b241ec754c7dea599fa15a869595f 100644 --- a/tensorflow/contrib/autograph/operators/data_structures_test.py +++ b/tensorflow/contrib/autograph/operators/data_structures_test.py @@ -19,25 +19,98 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph.operators import data_structures +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import test -class AppendTest(test.TestCase): +class ListTest(test.TestCase): - def test_tf_tensorarray(self): + def test_new_list_empty(self): + l = data_structures.new_list() + # Can't evaluate
an empty list. + # TODO(mdan): sess.run should allow tf.variant maybe? + self.assertTrue(isinstance(l, ops.Tensor)) + + def test_new_list_tensor(self): + l = data_structures.new_list([3, 4, 5]) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32) + with self.test_session() as sess: + self.assertAllEqual(sess.run(t), [3, 4, 5]) + + def test_append_tensor_list(self): + l = data_structures.new_list() + x = constant_op.constant([1, 2, 3]) + l = data_structures.list_append(l, x) + + t = list_ops.tensor_list_stack(l, element_dtype=x.dtype) + with self.test_session() as sess: + self.assertAllEqual(sess.run(t), [[1, 2, 3]]) + + def test_append_tensorarray(self): l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True) - l1 = data_structures.append(l, 1) - l2 = data_structures.append(l1, 2) + l1 = data_structures.list_append(l, 1) + l2 = data_structures.list_append(l1, 2) with self.test_session() as sess: self.assertAllEqual(sess.run(l1.stack()), [1]) self.assertAllEqual(sess.run(l2.stack()), [1, 2]) - def test_python(self): + def test_append_python(self): l = [] - self.assertAllEqual(data_structures.append(l, 1), [1]) - self.assertAllEqual(data_structures.append(l, 2), [1, 2]) + self.assertAllEqual(data_structures.list_append(l, 1), [1]) + self.assertAllEqual(data_structures.list_append(l, 2), [1, 2]) + + def test_pop_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + + opts = data_structures.ListPopOpts( + element_dtype=initial_list.dtype, + element_shape=(2,)) + + with self.assertRaises(NotImplementedError): + data_structures.list_pop(l, 0, opts) + + with self.test_session() as sess: + l, x = data_structures.list_pop(l, None, opts) + self.assertAllEqual(sess.run(x), [3, 4]) + + t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype) + self.assertAllEqual(sess.run(t), [[1, 2]]) + + def 
test_pop_python(self): + l = [1, 2, 3] + opts = data_structures.ListPopOpts(element_dtype=None, element_shape=()) + self.assertAllEqual(data_structures.list_pop(l, None, opts), ([1, 2], 3)) + self.assertAllEqual(data_structures.list_pop(l, None, opts), ([1], 2)) + + def test_stack_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + + opts = data_structures.ListStackOpts( + element_dtype=initial_list.dtype, original_call=None) + + with self.test_session() as sess: + t = data_structures.list_stack(l, opts) + self.assertAllEqual(sess.run(t), sess.run(initial_list)) + + def test_stack_fallback(self): + + def dummy_function(l): + # Lazy person's mock: just transform the argument in a way in which we + # can check that this function was indeed called. + return [x * 2 for x in l] + + opts = data_structures.ListStackOpts( + element_dtype=None, original_call=dummy_function) + + self.assertAllEqual(data_structures.list_stack([1, 2], opts), [2, 4]) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/contrib/autograph/operators/slices.py new file mode 100644 index 0000000000000000000000000000000000000000..04fbeb2f6e39234cad139442704fd7a8d0f56172 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/slices.py @@ -0,0 +1,133 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Operators specific to slicing operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import tensor_array_ops + + +# TODO(mdan): Support extended slices. + + +class GetItemOpts(collections.namedtuple('GetItemOpts', ('element_dtype',))): + pass + + +def get_item(target, i, opts): + """The slice read operator (i.e. __getitem__). + + Note: it is unspecified whether target will be mutated or not. In general, + if target is mutable (like Python lists), it will be mutated. + + Args: + target: An entity that supports getitem semantics. + i: Index to read from. + opts: A GetItemOpts object. + + Returns: + The read element. + + Raises: + ValueError: if target is not of a supported type. 
+ """ + assert isinstance(opts, GetItemOpts) + + if isinstance(target, tensor_array_ops.TensorArray): + return _tf_tensorarray_get_item(target, i) + elif tensor_util.is_tensor(target): + if target.dtype == dtypes.variant: + return _tf_tensor_list_get_item(target, i, opts) + else: + return _tf_tensor_get_item(target, i) + else: + return _py_get_item(target, i) + + +def _tf_tensorarray_get_item(target, i): + """Overload of get_item that stages a TensorArray read.""" + return target.read(i) + + +def _tf_tensor_list_get_item(target, i, opts): + """Overload of get_item that stages a Tensor list read.""" + if opts.element_dtype is None: + raise ValueError('cannot retrieve from a list without knowing its ' + 'element type; use set_element_type to annotate it') + x = list_ops.tensor_list_get_item(target, i, element_dtype=opts.element_dtype) + return x + + +def _tf_tensor_get_item(target, i): + """Overload of get_item that stages a Tensor (not Tensor list) read.""" + return target[i] + + +def _py_get_item(target, i): + """Overload of get_item that executes a Python list modification.""" + return target[i] + + +def set_item(target, i, x): + """The slice write operator (i.e. __setitem__). + + Note: it is unspecified whether target will be mutated or not. In general, + if target is mutable (like Python lists), it will be mutated. + + Args: + target: An entity that supports setitem semantics. + i: Index to modify. + x: The new element value. + + Returns: + Same as target, after the update was performed. + + Raises: + ValueError: if target is not of a supported type. 
+ """ + if isinstance(target, tensor_array_ops.TensorArray): + return _tf_tensorarray_set_item(target, i, x) + elif tensor_util.is_tensor(target): + if target.dtype == dtypes.variant: + return _tf_tensor_list_set_item(target, i, x) + else: + raise ValueError( + 'tensor lists are expected to be Tensors with dtype=tf.variant,' + ' instead found %s' % target) + else: + return _py_set_item(target, i, x) + + +def _tf_tensorarray_set_item(target, i, x): + """Overload of set_item that stages a TensorArray write.""" + return target.write(i, x) + + +def _tf_tensor_list_set_item(target, i, x): + """Overload of set_item that stages a Tensor list update.""" + return list_ops.tensor_list_set_item(target, i, x) + + +def _py_set_item(target, i, x): + """Overload of set_item that executes a Python list modification.""" + target[i] = x + return target diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/contrib/autograph/operators/slices_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d4aacb9d2015fec56a8df5ad85a20b733765ba26 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/slices_test.py @@ -0,0 +1,51 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for slices module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.operators import slices +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +class SlicesTest(test.TestCase): + + def test_set_item_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + l = slices.set_item(l, 0, [5, 6]) + + with self.test_session() as sess: + t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype) + self.assertAllEqual(sess.run(t), [[5, 6], [3, 4]]) + + def test_get_item_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + t = slices.get_item( + l, 1, slices.GetItemOpts(element_dtype=initial_list.dtype)) + + with self.test_session() as sess: + self.assertAllEqual(sess.run(t), [3, 4]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD index 796ab445c74128e1123e24b67c288e0e3c5ca24c..989b821e53a5cefbe39095e669f9a9e0bec65b8a 100644 --- a/tensorflow/contrib/autograph/pyct/BUILD +++ b/tensorflow/contrib/autograph/pyct/BUILD @@ -130,6 +130,7 @@ py_test( name = "transformer_test", srcs = ["transformer_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":pyct", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/autograph/pyct/anno.py b/tensorflow/contrib/autograph/pyct/anno.py index cc4a7edf02ed7556c9a552d8730e4c7875038c83..ae861627fd65cca057e7bf1af41424e605d4b7a1 100644 --- 
a/tensorflow/contrib/autograph/pyct/anno.py +++ b/tensorflow/contrib/autograph/pyct/anno.py @@ -46,8 +46,15 @@ class Basic(NoValue): '`name_map` allows renaming symbols.') -def getanno(node, key, field_name='___pyct_anno'): - return getattr(node, field_name)[key] +FAIL = object() + + +def getanno(node, key, default=FAIL, field_name='___pyct_anno'): + if (default is FAIL or + (hasattr(node, field_name) and (key in getattr(node, field_name)))): + return getattr(node, field_name)[key] + else: + return default def hasanno(node, key, field_name='___pyct_anno'): @@ -73,5 +80,9 @@ def delanno(node, key, field_name='___pyct_anno'): def copyanno(from_node, to_node, key, field_name='___pyct_anno'): - if hasanno(from_node, key, field_name): - setanno(to_node, key, getanno(from_node, key, field_name), field_name) + if hasanno(from_node, key, field_name=field_name): + setanno( + to_node, + key, + getanno(from_node, key, field_name=field_name), + field_name=field_name) diff --git a/tensorflow/contrib/autograph/pyct/anno_test.py b/tensorflow/contrib/autograph/pyct/anno_test.py index 1d4d9d119e0c45c4bf9dd4e5b8156766489a2e4d..f2c0c8cf05ca4b3671eb653ce56f6da61de54aee 100644 --- a/tensorflow/contrib/autograph/pyct/anno_test.py +++ b/tensorflow/contrib/autograph/pyct/anno_test.py @@ -38,12 +38,14 @@ class AnnoTest(test.TestCase): anno.setanno(node, 'foo', 3) self.assertTrue(anno.hasanno(node, 'foo')) - self.assertEqual(3, anno.getanno(node, 'foo')) + self.assertEqual(anno.getanno(node, 'foo'), 3) + self.assertEqual(anno.getanno(node, 'bar', default=7), 7) anno.delanno(node, 'foo') self.assertFalse(anno.hasanno(node, 'foo')) with self.assertRaises(AttributeError): anno.getanno(node, 'foo') + self.assertIsNone(anno.getanno(node, 'foo', default=None)) def test_copyanno(self): node_1 = ast.Name() diff --git a/tensorflow/contrib/autograph/pyct/qual_names.py b/tensorflow/contrib/autograph/pyct/qual_names.py index 
583cf7ecd7bce31c55de58361ab5295abb5d6707..da07013cf4f4309c0e24adda3017575d942861b7 100644 --- a/tensorflow/contrib/autograph/pyct/qual_names.py +++ b/tensorflow/contrib/autograph/pyct/qual_names.py @@ -205,6 +205,7 @@ class QnResolver(gast.NodeTransformer): return node def visit_Subscript(self, node): + # TODO(mdan): This may no longer apply if we overload getitem. node = self.generic_visit(node) s = node.slice if not isinstance(s, gast.Index): @@ -216,7 +217,11 @@ class QnResolver(gast.NodeTransformer): elif isinstance(s.value, gast.Str): subscript = QN(StringLiteral(s.value.s)) else: - subscript = anno.getanno(node.slice.value, anno.Basic.QN) + # The index may be an expression, case in which a name doesn't make sense. + if anno.hasanno(node.slice.value, anno.Basic.QN): + subscript = anno.getanno(node.slice.value, anno.Basic.QN) + else: + return node if anno.hasanno(node.value, anno.Basic.QN): anno.setanno(node, anno.Basic.QN, QN(anno.getanno(node.value, anno.Basic.QN), diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py index 230e4cc0f3311ac5ad1e80c2591896ee48866280..ad97fdfa8e78d1fd4c38724612d83519c6609cce 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py @@ -135,8 +135,7 @@ class CfgBuilder(gast.NodeVisitor): # Handle the body self.visit_statements(node.body) body_exit = self.current_leaves - self.current_leaves = [] - self.current_leaves.append(test) + self.current_leaves = [test] # Handle the orelse self.visit_statements(node.orelse) self.current_leaves.extend(body_exit) @@ -149,12 +148,15 @@ class CfgBuilder(gast.NodeVisitor): self.continue_.append([]) # Handle the body self.visit_statements(node.body) + body_exit = self.current_leaves self.current_leaves.extend(self.continue_.pop()) self.set_current_leaves(test) # Handle the orelse self.visit_statements(node.orelse) # The break statements and the 
test go to the next node self.current_leaves.extend(self.break_.pop()) + # Body and orelse statements can reach out of the loop + self.current_leaves.extend(body_exit) def visit_For(self, node): iter_ = CfgNode(node.iter) @@ -162,9 +164,15 @@ class CfgBuilder(gast.NodeVisitor): self.break_.append([]) self.continue_.append([]) self.visit_statements(node.body) + body_exit = self.current_leaves self.current_leaves.extend(self.continue_.pop()) self.set_current_leaves(iter_) + # Handle the orelse + self.visit_statements(node.orelse) + # The break statements and the test go to the next node self.current_leaves.extend(self.break_.pop()) + # Body and orelse statements can reach out of the loop + self.current_leaves.extend(body_exit) def visit_Break(self, node): self.break_[-1].extend(self.current_leaves) @@ -395,7 +403,13 @@ class Liveness(Backward): super(Liveness, self).__init__('live', context) def get_gen_kill(self, node, _): + # A variable's parents are live if it is live + # e.g. x is live if x.y is live. This means gen needs to return + # all parents of a variable (if it's an Attribute or Subscript). + # This doesn't apply to kill (e.g. 
del x.y doesn't affect liveness of x) gen = activity.get_read(node.value, self.context) + gen = functools.reduce(lambda left, right: left | right.support_set, gen, + gen) kill = activity.get_updated(node.value, self.context) return gen, kill diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py index af7eaf30e8d403acc18d79ac1dd9e98673c333a2..fc07fa3447b23c0595a5893329de8a2d7055ca15 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py @@ -115,20 +115,27 @@ class CFGTest(test.TestCase): if_body = body[0].body self._check_anno_matches(if_body[0], 'defined_out', ('x', 'y')) - # TODO(alexbw): b/73926938 split this test up - def test_live(self): + def _get_live_annotated_fnbody(self, f): + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Liveness(ctx)) + body = node.body[0].body + return body - def get_live_annotated_fnbody(f): - node, ctx = self._parse_and_analyze(f, {}) - cfg.run_analyses(node, cfg.Liveness(ctx)) - body = node.body[0].body - return body + def test_live_straightline(self): def f1(x): a = g(x) # pylint: disable=undefined-variable b = h(a) # pylint: disable=undefined-variable, unused-variable return x + body = self._get_live_annotated_fnbody(f1) + self._check_anno_matches(body[1], 'live_in', ('a', 'h', 'x')) + self._check_anno_matches(body[2], 'live_in', ('x')) + self._check_anno_matches(body[0], 'live_in', ('g', 'h', 'x')) + self._check_anno_matches(body[2], 'live_out', ()) + + def test_live_stacked_conds_with_else(self): + def f2(x, a): # pylint: disable=unused-argument if a > 0: # x should not be live x = 0 @@ -137,6 +144,12 @@ class CFGTest(test.TestCase): else: x = 2 + body = self._get_live_annotated_fnbody(f2) + self._check_anno_matches(body[0], 'live_in', ('a')) + self._check_anno_matches(body[1], 'live_in', ('a')) + + def test_live_stacked_conds(self): 
+ def f3(x, a): if a > 0: # x and a should be live x = 0 @@ -144,58 +157,58 @@ class CFGTest(test.TestCase): x = 1 return x # x should be live + body = self._get_live_annotated_fnbody(f3) + self._check_anno_matches(body[0], 'live_in', ('a', 'x')) + self._check_anno_matches(body[1], 'live_in', ('a', 'x')) + self._check_anno_matches(body[2], 'live_in', ('x')) + + def test_live_possibly_unused_cond(self): + def f4(x, a): if a > 0: # x should be live x = 0 x += 1 + body = self._get_live_annotated_fnbody(f4) + self._check_anno_matches(body[0], 'live_in', ('x', 'a')) + self._check_anno_matches(body[1], 'live_in', ('x')) + + def test_live_attribute_in_cond(self): + def f5(x, a): if a > 0: # x.y should be live x.y = 0 return x.y + body = self._get_live_annotated_fnbody(f5) + self._check_anno_matches(body[0], 'live_in', ('x', 'x.y', 'a')) + + def test_live_noop(self): + def f6(x): return x # should this cause x.* to be live? + body = self._get_live_annotated_fnbody(f6) + self._check_anno_matches(body[0], 'live_in', ('x')) + + def test_live_loop(self): + def f7(x, n): for i in range(n): x += i return x - def f8(x, f): - with f: - x += 1 - - body = get_live_annotated_fnbody(f1) - self._check_anno_matches(body[1], 'live_in', ('a', 'h', 'x')) - self._check_anno_matches(body[2], 'live_in', ('x')) - self._check_anno_matches(body[0], 'live_in', ('g', 'h', 'x')) - self._check_anno_matches(body[2], 'live_out', ()) - - body = get_live_annotated_fnbody(f2) - self._check_anno_matches(body[0], 'live_in', ('a')) - self._check_anno_matches(body[1], 'live_in', ('a')) - - body = get_live_annotated_fnbody(f3) - self._check_anno_matches(body[0], 'live_in', ('a', 'x')) - self._check_anno_matches(body[1], 'live_in', ('a', 'x')) - self._check_anno_matches(body[2], 'live_in', ('x')) - - body = get_live_annotated_fnbody(f4) - self._check_anno_matches(body[0], 'live_in', ('x', 'a')) + body = self._get_live_annotated_fnbody(f7) + self._check_anno_matches(body[0], 'live_in', ('x', 'n', 'range')) 
self._check_anno_matches(body[1], 'live_in', ('x')) - body = get_live_annotated_fnbody(f5) - self._check_anno_matches(body[0], 'live_in', ('x', 'x.y', 'a')) + def test_live_context_manager(self): - body = get_live_annotated_fnbody(f6) - self._check_anno_matches(body[0], 'live_in', ('x')) - - body = get_live_annotated_fnbody(f7) - self._check_anno_matches(body[0], 'live_in', ('x', 'n', 'range')) - self._check_anno_matches(body[1], 'live_in', ('x')) + def f8(x, f): + with f: + x += 1 - body = get_live_annotated_fnbody(f8) + body = self._get_live_annotated_fnbody(f8) self._check_anno_matches(body[0], 'live_in', ('f', 'x')) def test_node_equality(self): @@ -247,6 +260,47 @@ class CFGTest(test.TestCase): anno.getanno(body[2], 'defined_in'), frozenset(map(qual_names.QN, ('x', 'g')))) + def test_loop_else(self): + + # Disabling useless-else-on-loop error, because 'break' and 'continue' + # canonicalization are a separate analysis pass, and here we test + # the CFG analysis in isolation. + def for_orelse(x): + y = 0 + for i in range(len(x)): + x += i + else: # pylint: disable=useless-else-on-loop + y = 1 + return x, y + + def while_orelse(x, i): + y = 0 + while x < 10: + x += i + else: # pylint: disable=useless-else-on-loop + y = 1 + return x, y + + for f in (for_orelse, while_orelse): + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) + body = node.body[0].body + return_node = body[-1] + reaching_defs = anno.getanno(return_node, 'definitions_in') + + # Y could be defined by Assign(Num(0)) or Assign(Num(1)) + # X could be defined as an argument or an AugAssign. 
+ y_defs = [node for var, node in reaching_defs if str(var) == 'y'] + x_defs = [node for var, node in reaching_defs if str(var) == 'x'] + + self.assertEqual(set((gast.Assign,)), set(type(def_) for def_ in y_defs)) + self.assertEqual(set((0, 1)), set(def_.value.n for def_ in y_defs)) + self.assertEqual(len(y_defs), 2) + self.assertEqual( + set((gast.arguments, gast.AugAssign)), + set(type(def_) for def_ in x_defs)) + self.assertEqual(len(x_defs), 2) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py index c00946f9c41bc68d5c638d71f356b484db1286d1..7d1e65c958d7787ef5ed707d4822d14a83092975 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py @@ -17,8 +17,8 @@ This analyzer uses known live values to further infer object types. This may include for instance constructed objects and object member functions. -In addition, the analyzer will also process annotations for TF (staged) type -annotations. +In addition, the analyzer also handles user annotations made in the code (for +example, the autograph.set_element_type function). Requires annotations generated by LiveValuesResolver. """ @@ -44,6 +44,7 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect @@ -136,14 +137,14 @@ class TypeInfoResolver(transformer.Base): def _process_function_arg(self, arg_name): str_name = str(arg_name) + type_holder = arg_name.ast() + self.scope.setval(arg_name, type_holder) if len(self.enclosing_entities) == 1 and str_name in self.context.arg_types: # Forge a node to hold the type information, so that method calls on # it can resolve the type. 
- type_holder = arg_name.ast() type_string, type_obj = self.context.arg_types[str_name] anno.setanno(type_holder, 'type', type_obj) anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.'))) - self.scope.setval(arg_name, type_holder) def visit_arg(self, node): self._process_function_arg(anno.getanno(node.arg, anno.Basic.QN)) @@ -159,58 +160,47 @@ class TypeInfoResolver(transformer.Base): # a = b # then for future references to `a` we should have definition = `b` definition = self.scope.getval(qn) - if anno.hasanno(definition, 'type'): - anno.setanno(node, 'type', anno.getanno(definition, 'type')) - anno.setanno(node, 'type_fqn', anno.getanno(definition, 'type_fqn')) - if anno.hasanno(definition, 'element_type'): - anno.setanno(node, 'element_type', - anno.getanno(definition, 'element_type')) + anno.copyanno(definition, node, 'type') + anno.copyanno(definition, node, 'type_fqn') + anno.copyanno(definition, node, 'element_type') + anno.copyanno(definition, node, 'element_shape') return node - def _process_variable_assignment(self, source, targets): - # Special case: constructors. - if isinstance(source, gast.Call): - func = source.func + def _process_variable_assignment(self, target, value): + # Constructors + if isinstance(value, gast.Call): + func = value.func if anno.hasanno(func, 'live_val'): func_obj = anno.getanno(func, 'live_val') if tf_inspect.isclass(func_obj): - anno.setanno(source, 'is_constructor', True) - anno.setanno(source, 'type', func_obj) - anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn')) + anno.setanno(value, 'is_constructor', True) + anno.setanno(value, 'type', func_obj) + anno.setanno(value, 'type_fqn', anno.getanno(func, 'fqn')) # TODO(mdan): Raise an error if constructor has side effects. # We can have a whitelist of no-side-effects constructors. # We can also step inside the constructor and further analyze. - # Multiple targets mean multiple assignment. - for target in targets: - # Tuple target means unpacking. 
- if isinstance(target, (gast.Tuple, gast.List)): - for i, target_item in enumerate(target.elts): - # Two cases here: - # 1. Static unpacking, e.g. a, b = c, d - # 2. Dynamic unpacking, e.g. a, b = c - # The former case is optimized away. - if isinstance(source, (gast.Tuple, gast.List)): - source_item = source.elts[i] - else: - source_item = gast.Subscript(source, gast.Index(i), ctx=None) - self._process_variable_assignment(source_item, (target_item,)) - elif isinstance(target, (gast.Name, gast.Attribute)): - target_symbol = anno.getanno(target, anno.Basic.QN) - self.scope.setval(target_symbol, source) - else: - raise ValueError('assignment target has unknown type: %s' % target) + if isinstance(target, (gast.Name, gast.Attribute)): + target_symbol = anno.getanno(target, anno.Basic.QN) + self.scope.setval(target_symbol, value) + elif isinstance(target, gast.Subscript): + pass + else: + raise ValueError('assignment target has unknown type: %s' % target) def visit_With(self, node): - for wi in node.items: - if wi.optional_vars is not None: - self._process_variable_assignment(wi.context_expr, (wi.optional_vars,)) + for item in node.items: + if item.optional_vars is not None: + self.apply_to_single_assignments((item.optional_vars,), + item.context_expr, + self._process_variable_assignment) self.generic_visit(node) return node def visit_Assign(self, node): self.generic_visit(node) - self._process_variable_assignment(node.value, node.targets) + self.apply_to_single_assignments( + node.targets, node.value, self._process_variable_assignment) return node def visit_Call(self, node): @@ -220,23 +210,20 @@ class TypeInfoResolver(transformer.Base): if (anno.getanno(node.func, 'live_val') is self.context.type_annotation_func): - if len(node.args) != 2: - raise ValueError('"%s" must have exactly two parameters' + if len(node.args) < 2 or len(node.args) > 3: + raise ValueError('"%s" must have either two or three parameters' % self.context.type_annotation_func) - target_arg, 
type_arg = node.args + if len(node.args) == 2: + target_arg, type_arg = node.args + shape_arg = parser.parse_expression('None') + else: + target_arg, type_arg, shape_arg = node.args if not anno.hasanno(target_arg, anno.Basic.QN): raise ValueError('the first argument of "%s" must by a symbol' % self.context.type_annotation_func) - if isinstance(type_arg, gast.Str): - element_type = type_arg.s - elif isinstance(type_arg, gast.Num): - element_type = type_arg.n - else: - if not anno.hasanno(type_arg, 'live_val'): - raise ValueError( - 'the second argument of "%s" must be statically resolvable' % - self.context.type_annotation_func) - element_type = anno.getanno(type_arg, 'live_val') + # TODO(mdan): This is vulnerable to symbol renaming. + element_type = type_arg + element_shape = shape_arg target_symbol = anno.getanno(target_arg, anno.Basic.QN) # Find the definition of this symbol and annotate it with the given @@ -244,7 +231,9 @@ class TypeInfoResolver(transformer.Base): # to receive the same type annotation. definition = self.scope.getval(target_symbol) anno.setanno(node, 'element_type', element_type) + anno.setanno(node, 'element_shape', element_shape) anno.setanno(definition, 'element_type', element_type) + anno.setanno(definition, 'element_shape', element_shape) # TODO(mdan): Should we update references between definition and here? 
return self.generic_visit(node) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py index 46b7701624a43073fb7cc612d2678ab851513d91..484562f294bb53a63feeca965b8f94c58aa2a685 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py @@ -187,14 +187,27 @@ class TypeInfoResolverTest(test.TestCase): def test_fn(): f = [] - f = utils.set_element_type(f, Foo) + f = utils.set_element_type(f, Foo, (1, 2, 3)) return f node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils}) f_def = node.body[0].body[0].value - self.assertEqual(anno.getanno(f_def, 'element_type'), Foo) + self.assertEqual(anno.getanno(f_def, 'element_type').id, 'Foo') f_ref = node.body[0].body[1].value - self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo) + self.assertEqual(anno.getanno(f_ref, 'element_type').id, 'Foo') + + def test_type_annotation_args(self): + + class Foo(object): + pass + + def test_fn(f): + utils.set_element_type(f, Foo) + return f + + node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils}) + f_ref = node.body[0].body[1].value + self.assertEqual(anno.getanno(f_ref, 'element_type').id, 'Foo') def test_nested_unpacking(self): @@ -210,9 +223,9 @@ class TypeInfoResolverTest(test.TestCase): node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'Bar': Bar}) a, b, c = node.body[0].body[1].value.elts - self.assertEquals(Foo, anno.getanno(a, 'type')) - self.assertEquals(Bar, anno.getanno(b, 'type')) - self.assertEquals(Foo, anno.getanno(c, 'type')) + self.assertEquals(anno.getanno(a, 'type'), Foo) + self.assertEquals(anno.getanno(b, 'type'), Bar) + self.assertEquals(anno.getanno(c, 'type'), Foo) self.assertFalse(anno.hasanno(a, 'live_val')) self.assertFalse(anno.hasanno(b, 'live_val')) self.assertFalse(anno.hasanno(c, 'live_val')) @@ -229,8 +242,8 @@ class 
TypeInfoResolverTest(test.TestCase): node = self._parse_and_analyze(test_fn, {'utils': utils}) a, b = node.body[0].body[2].body[2].value.elts - self.assertEquals(1, anno.getanno(a, 'element_type')) - self.assertEquals(2, anno.getanno(b, 'element_type')) + self.assertEquals(anno.getanno(a, 'element_type').n, 1) + self.assertEquals(anno.getanno(b, 'element_type').n, 2) self.assertFalse(anno.hasanno(a, 'type')) self.assertFalse(anno.hasanno(b, 'type')) self.assertFalse(anno.hasanno(a, 'live_val')) diff --git a/tensorflow/contrib/autograph/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py index baf7923fff7c786c1abd05e11fa6ffdb8c8f0912..9c479ebc2fa83d27dc363ae306daedb556734a1f 100644 --- a/tensorflow/contrib/autograph/pyct/templates.py +++ b/tensorflow/contrib/autograph/pyct/templates.py @@ -239,8 +239,13 @@ def replace_as_expression(template, **replacements): raise ValueError( 'single expression expected; for more general templates use replace') node = replacement[0] - if not isinstance(node, gast.Expr): - raise ValueError( - 'the template is expected to generate an expression node; instead ' - 'found %s' % node) - return node.value + node = qual_names.resolve(node) + + if isinstance(node, gast.Expr): + return node.value + elif isinstance(node, gast.Name): + return node + + raise ValueError( + 'the template is expected to generate an expression or a name node;' + ' instead found %s' % node) diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py index 4db6cc0adfad90ffc1a6bbcadfc80215688d271e..60bca8b38dcf62b4e997379d075cfc45511a894f 100644 --- a/tensorflow/contrib/autograph/pyct/transformer.py +++ b/tensorflow/contrib/autograph/pyct/transformer.py @@ -70,14 +70,40 @@ class Base(gast.NodeTransformer): return tuple(self._enclosing_entities) @property - def locel_scope_level(self): + def local_scope_level(self): return len(self._local_scope_state) - def enter_local_scope(self): - 
self._local_scope_state.append({}) + def enter_local_scope(self, inherit=None): + """Marks entry into a new local scope. - def exit_local_scope(self): - return self._local_scope_state.pop() + Args: + inherit: Optional enumerable of variable names to copy from the + parent scope. + """ + scope_entered = {} + if inherit: + this_scope = self._local_scope_state[-1] + for name in inherit: + if name in this_scope: + scope_entered[name] = this_scope[name] + self._local_scope_state.append(scope_entered) + + def exit_local_scope(self, keep=None): + """Marks exit from the current local scope. + + Args: + keep: Optional enumerable of variable names to copy into the + parent scope. + Returns: + A dict containing the scope that has just been exited. + """ + scope_left = self._local_scope_state.pop() + if keep: + this_scope = self._local_scope_state[-1] + for name in keep: + if name in scope_left: + this_scope[name] = scope_left[name] + return scope_left def set_local(self, name, value): self._local_scope_state[-1][name] = value @@ -91,38 +117,163 @@ class Base(gast.NodeTransformer): print(pretty_printer.fmt(node)) return node - def visit_block(self, nodes): - """Helper equivalent to generic_visit, but for node lists.""" + def visit_block(self, nodes, before_visit=None, after_visit=None): + """A more powerful version of generic_visit for statement blocks. + + An example of a block is the body of an if statement. + + This function allows specifying a postprocessing callback (the + after_visit argument) which can be used to move nodes to a new + destination. This is done by after_visit by returning a non-null + second return value, e.g. return new_node, new_destination.
+ + For example, a transformer could perform the following move: + + foo() + bar() + baz() + + foo() + if cond: + bar() + baz() + + The above could be done with a postprocessor of this kind: + + def after_visit(node): + if node_is_function_call(bar): + new_container_node = build_cond() + new_container_node.body.append(node) + return new_container_node, new_container_node.body + else: + # Once we set a new destination, all subsequent items will be + # moved to it, so we don't need to explicitly handle baz. + return node, None + + Args: + nodes: enumerable of AST node objects + before_visit: optional callable that is called before visiting each item + in nodes + after_visit: optional callable that takes in an AST node and + returns a tuple (new_node, new_destination). It is called after + visiting each item in nodes. Is used in the same way as the + visit_* methods: new_node will replace the node; if not None, + new_destination must be a list, and subsequent nodes will be placed + in this list instead of the list returned by visit_block. + Returns: + A list of AST node objects containing the transformed items from nodes, + except those nodes that have been relocated using after_visit. + """ results = [] + node_destination = results for node in nodes: + if before_visit: + # TODO(mdan): We can modify node here too, if ever needed. + before_visit() + replacement = self.visit(node) + + if after_visit and replacement: + replacement, new_destination = after_visit(replacement) + else: + new_destination = None + if replacement: if isinstance(replacement, (list, tuple)): - results.extend(replacement) + node_destination.extend(replacement) else: - results.append(replacement) + node_destination.append(replacement) + + # Allow the postprocessor to reroute the remaining nodes to a new list. + if new_destination is not None: + node_destination = new_destination return results + # TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
+ def apply_to_single_assignments(self, targets, values, apply_fn): + """Applies a function to each individual assignment. + + This function can process a possibly-unpacked (e.g. a, b = c, d) assignment. + It tries to break down the unpacking if possible. In effect, it has the same + effect as passing the assigned values in SSA form to apply_fn. + + Examples: + + The following will result in apply_fn(a, c), apply_fn(b, d): + + a, b = c, d + + The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]): + + a, b = c + + The following will result in apply_fn(a, (b, c)): + + a = b, c + + It uses the visitor pattern to allow subclasses to process single + assignments individually. + + Args: + targets: list, tuple of or individual AST node. Should be used with the + targets field of an ast.Assign node. + values: an AST node. + apply_fn: a function of a single argument, which will be called with the + respective nodes of each single assignment. The signature is + apply_fn(target, value), no return value. + """ + if not isinstance(targets, (list, tuple)): + targets = (targets,) + for target in targets: + if isinstance(target, (gast.Tuple, gast.List)): + for i in range(len(target.elts)): + target_el = target.elts[i] + if isinstance(values, (gast.Tuple, gast.List)): + value_el = values.elts[i] + else: + value_el = gast.Subscript(values, gast.Index(i), ctx=gast.Store()) + self.apply_to_single_assignments(target_el, value_el, apply_fn) + else: + # TODO(mdan): Look into allowing to rewrite the AST here.
+ apply_fn(target, values) + def visit(self, node): source_code = self.context.source_code source_file = self.context.source_file did_enter_function = False - local_scope_state_size = len(self._local_scope_state) + local_scope_size_at_entry = len(self._local_scope_state) try: if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)): - self._enclosing_entities.append(node) did_enter_function = True + if did_enter_function: + self._enclosing_entities.append(node) + if source_code and hasattr(node, 'lineno'): self._lineno = node.lineno self._col_offset = node.col_offset - if anno.hasanno(node, anno.Basic.SKIP_PROCESSING): - return node - return super(Base, self).visit(node) - except (ValueError, AttributeError, KeyError, NotImplementedError, - AssertionError) as e: + if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING): + result = super(Base, self).visit(node) + + # On exception, the local scope integrity is not guaranteed. + if did_enter_function: + self._enclosing_entities.pop() + + if local_scope_size_at_entry != len(self._local_scope_state): + raise AssertionError( + 'Inconsistent local scope stack. Before entering node %s, the' + ' stack had length %d, after exit it has length %d. This' + ' indicates enter_local_scope and exit_local_scope are not' + ' well paired.' % ( + node, + local_scope_size_at_entry, + len(self._local_scope_state) + )) + return result + + except (ValueError, AttributeError, KeyError, NotImplementedError) as e: msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % ( e.__class__.__name__, str(e), try_ast_to_source(node), pretty_printer.fmt(node, color=False)) @@ -130,18 +281,11 @@ class Base(gast.NodeTransformer): line = source_code.splitlines()[self._lineno - 1] else: line = '' + # TODO(mdan): Avoid the printing of the original exception. + # In other words, we need to find how to suppress the "During handling + # of the above exception, another exception occurred" message. 
six.reraise(AutographParseError, AutographParseError( msg, (source_file, self._lineno, self._col_offset + 1, line)), sys.exc_info()[2]) - finally: - if did_enter_function: - self._enclosing_entities.pop() - - if local_scope_state_size != len(self._local_scope_state): - raise AssertionError( - 'Inconsistent local scope stack. Before entering node %s, the' - ' stack had length %d, after exit it has length %d. This' - ' indicates enter_local_scope and exit_local_scope are not' - ' well paired.') diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/contrib/autograph/pyct/transformer_test.py index f96b0dc377521a482d347436caa98633a0a32c8a..f110e79605945e908e8a49112cf758ec29fa1b11 100644 --- a/tensorflow/contrib/autograph/pyct/transformer_test.py +++ b/tensorflow/contrib/autograph/pyct/transformer_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gast + from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser @@ -27,7 +29,7 @@ from tensorflow.python.platform import test class TransformerTest(test.TestCase): - def _context_for_nodetesting(self): + def _context_for_testing(self): return context.EntityContext( namer=None, source_code=None, @@ -53,7 +55,7 @@ class TransformerTest(test.TestCase): anno.setanno(node, 'enclosing_entities', self.enclosing_entities) return self.generic_visit(node) - tr = TestTransformer(self._context_for_nodetesting()) + tr = TestTransformer(self._context_for_testing()) def test_function(): a = 0 @@ -94,7 +96,7 @@ class TransformerTest(test.TestCase): inner_function, lambda_node), anno.getanno(lambda_expr, 'enclosing_entities')) - def test_statement_info_stack(self): + def test_local_scope_info_stack(self): class TestTransformer(transformer.Base): @@ -116,7 +118,7 @@ class TransformerTest(test.TestCase): def visit_For(self, node): 
return self._annotate_result(node) - tr = TestTransformer(self._context_for_nodetesting()) + tr = TestTransformer(self._context_for_testing()) def test_function(a): """Docstring.""" @@ -142,7 +144,7 @@ class TransformerTest(test.TestCase): self.assertFalse(anno.hasanno(while_node, 'string')) self.assertEqual('1', anno.getanno(while_node, 'test')) - def test_statement_info_stack_checks_integrity(self): + def test_local_scope_info_stack_checks_integrity(self): class TestTransformer(transformer.Base): @@ -155,7 +157,7 @@ class TransformerTest(test.TestCase): self.exit_local_scope() return node - tr = TestTransformer(self._context_for_nodetesting()) + tr = TestTransformer(self._context_for_testing()) def no_exit(a): if a > 0: @@ -174,6 +176,38 @@ class TransformerTest(test.TestCase): with self.assertRaises(AssertionError): tr.visit(node) + def test_visit_block_postprocessing(self): + + class TestTransformer(transformer.Base): + + def _process_body_item(self, node): + if isinstance(node, gast.Assign) and (node.value.id == 'y'): + if_node = gast.If(gast.Name('x', gast.Load(), None), [node], []) + return if_node, if_node.body + return node, None + + def visit_FunctionDef(self, node): + node.body = self.visit_block( + node.body, after_visit=self._process_body_item) + return node + + def test_function(x, y): + z = x + z = y + return z + + tr = TestTransformer(self._context_for_testing()) + + node, _ = parser.parse_entity(test_function) + node = tr.visit(node) + node = node.body[0] + + self.assertEqual(len(node.body), 2) + self.assertTrue(isinstance(node.body[0], gast.Assign)) + self.assertTrue(isinstance(node.body[1], gast.If)) + self.assertTrue(isinstance(node.body[1].body[0], gast.Assign)) + self.assertTrue(isinstance(node.body[1].body[1], gast.Return)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD index 
d3a1b9468892531cbc51bc13de66ef595f1a95f8..d82c17bf2afd01aedf4344f983b02c09abcb9bad 100644 --- a/tensorflow/contrib/autograph/utils/BUILD +++ b/tensorflow/contrib/autograph/utils/BUILD @@ -33,6 +33,8 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:dtypes", "//tensorflow/python:list_ops", "//tensorflow/python:script_ops", "//tensorflow/python/data/ops:dataset_ops", diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py index 211e8eaee9082dd3e4f035e4379871cd2e154a39..998087e056c2cd264399982220d6e0528aab9edb 100644 --- a/tensorflow/contrib/autograph/utils/builtins.py +++ b/tensorflow/contrib/autograph/utils/builtins.py @@ -24,6 +24,7 @@ import six from tensorflow.contrib.autograph.utils import py_func from tensorflow.contrib.autograph.utils import type_check +from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import logging_ops @@ -38,7 +39,13 @@ def dynamic_builtin(f, *args, **kwargs): return dynamic_range(*args, **kwargs) if f is range: return dynamic_range(*args, **kwargs) - raise ValueError('%s is not supported' % f) + if f is int: + return dynamic_int(*args, **kwargs) + if f is float: + return dynamic_float(*args, **kwargs) + + raise NotImplementedError( + 'The "%s" builtin is not yet supported.' 
% f.__name__) def dynamic_len(list_or_tensor): @@ -52,6 +59,20 @@ def dynamic_len(list_or_tensor): return len(list_or_tensor) +def dynamic_int(num_or_tensor, **kwargs): + """Implementation of int() using dynamic dispatch.""" + if tensor_util.is_tensor(num_or_tensor): + return math_ops.cast(num_or_tensor, dtype=dtypes.int32, **kwargs) + return int(num_or_tensor) + + +def dynamic_float(num_or_tensor, **kwargs): + """Implementation of float() using dynamic dispatch.""" + if tensor_util.is_tensor(num_or_tensor): + return math_ops.cast(num_or_tensor, dtype=dtypes.float32, **kwargs) + return float(num_or_tensor) + + def dynamic_range(start_or_stop, stop=None, step=None): """Implementation of range using dynamic dispatch.""" if type_check.is_tensor(start_or_stop, stop, step): diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py index 163e6984079fea5c3b3d9aeda0ec8048d651686f..0c2312178a921037fa419818bf309d671c33914d 100644 --- a/tensorflow/contrib/autograph/utils/builtins_test.py +++ b/tensorflow/contrib/autograph/utils/builtins_test.py @@ -24,6 +24,7 @@ import six from tensorflow.contrib.autograph.utils import builtins from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.platform import test @@ -77,7 +78,7 @@ class BuiltinsTest(test.TestCase): return x # Functions that just have the names of builtins are rejected. 
- with self.assertRaises(ValueError): + with self.assertRaises(NotImplementedError): self.assertEqual(builtins.dynamic_builtin(range, 1), 1) if six.PY2: self.assertListEqual( @@ -87,6 +88,20 @@ class BuiltinsTest(test.TestCase): self.assertListEqual( list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2]) + def test_casts(self): + i = constant_op.constant(2, dtype=dtypes.int32) + f = constant_op.constant(1.0, dtype=dtypes.float32) + + self.assertEqual(builtins.dynamic_builtin(int, i).dtype, dtypes.int32) + self.assertEqual(builtins.dynamic_builtin(int, f).dtype, dtypes.int32) + self.assertEqual(builtins.dynamic_builtin(float, i).dtype, dtypes.float32) + self.assertEqual(builtins.dynamic_builtin(float, f).dtype, dtypes.float32) + + self.assertEqual(builtins.dynamic_builtin(int, True), 1) + self.assertEqual(builtins.dynamic_builtin(int, False), 0) + self.assertEqual(builtins.dynamic_builtin(float, True), 1.0) + self.assertEqual(builtins.dynamic_builtin(float, False), 0.0) + def test_dynamic_print_tf(self): try: out_capturer = six.StringIO() diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index d65c990c87cbc316472237d183c03765416501e7..b27a19b16c08cb588b45949105a6399623e766e1 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -49,6 +49,14 @@ cc_library( ], ) +cc_library( + name = "serial_device_batch_scheduler", + hdrs = ["serial_device_batch_scheduler.h"], + deps = [ + "//tensorflow/core/kernels/batching_util:serial_device_batch_scheduler", + ], +) + cc_library( name = "basic_batch_scheduler", hdrs = ["basic_batch_scheduler.h"], @@ -96,6 +104,7 @@ py_test( name = "batch_ops_test", size = "small", srcs = ["python/ops/batch_ops_test.py"], + shard_count = 5, srcs_version = "PY2AND3", tags = [ "manual", diff --git a/tensorflow/contrib/batching/__init__.py b/tensorflow/contrib/batching/__init__.py index 44fa5f42a73bfb1bf008f6f4eafd14913c88dcfa..1e503a097a7b72d9244b0a1cf57747c4b4122c81 100644 
--- a/tensorflow/contrib/batching/__init__.py +++ b/tensorflow/contrib/batching/__init__.py @@ -14,6 +14,7 @@ # ============================================================================== """Ops and modules related to batch. +@@batch_function_v1 @@batch_function """ from __future__ import absolute_import diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py index 921d6917a4e478c3e60771fdc3ae99febc33d2e3..012a51f71101471850d312033c41dcbc4805d44c 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import gen_batch_ops # go/tf-wildcard-import @@ -83,6 +84,74 @@ def batch_function(num_batch_threads, SparseTensor is not supported. The return value of the decorated function must be a Tensor or a list/tuple of Tensors. + Args: + num_batch_threads: Number of scheduling threads for processing batches + of work. Determines the number of batches processed in parallel. + max_batch_size: Batch sizes will never be bigger than this. + batch_timeout_micros: Maximum number of microseconds to wait before + outputting an incomplete batch. + allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, + does nothing. Otherwise, supplies a list of batch sizes, causing the op + to pad batches up to one of those sizes. The entries must increase + monotonically, and the final entry must equal max_batch_size. + grad_timeout_micros: The timeout to use for the gradient. See the + documentation of the unbatch op for more details. Defaults to 60s. + unbatch_timeout_micros: The timeout to use for unbatching. See the + documentation of the unbatch op for more details. Defaults to 60s. 
+ max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10. + + Returns: + The decorated function will return the unbatched computation output Tensors. + """ + + def decorator(fn): # pylint: disable=missing-docstring + + def decorated(*args): # pylint: disable=missing-docstring + types = [arg.dtype for arg in args] + + @function.Defun(*types) + def computation(*computation_args): + return fn(*computation_args) + + with ops.name_scope("batch") as name: + for a in args: + if not isinstance(a, ops.Tensor): + raise ValueError("All arguments to functions decorated with " + "`batch_function` are supposed to be Tensors; " + "found %s" % repr(a)) + for inp in computation.captured_inputs: + print("inp: %s" % inp) + for op in inp.consumers(): + print("op: %s" % op) + return gen_batch_ops.batch_function( + num_batch_threads=num_batch_threads, + max_batch_size=max_batch_size, + batch_timeout_micros=batch_timeout_micros, + allowed_batch_sizes=allowed_batch_sizes, + max_enqueued_batches=max_enqueued_batches, + shared_name=name, + f=computation, + in_tensors=list(args), + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + return decorated + + return decorator + + +def batch_function_v1(num_batch_threads, + max_batch_size, + batch_timeout_micros, + allowed_batch_sizes=None, + grad_timeout_micros=60 * 1000 * 1000, + unbatch_timeout_micros=60 * 1000 * 1000, + max_enqueued_batches=10): + """Batches the computation done by the decorated function. + + This is the older version of batch_function(). Please use the former instead + of this. + Args: num_batch_threads: Number of scheduling threads for processing batches of work. Determines the number of batches processed in parallel. 
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index e22f978dde6f1b7febc771d526201579c20292c7..78468145469df216344bc00f116add250dc51dd3 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -23,7 +23,10 @@ import time from tensorflow.contrib.batching.python.ops import batch_ops from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework.errors import InvalidArgumentError from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_batch_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import script_ops from tensorflow.python.platform import test @@ -185,12 +188,62 @@ class BatchOpsTest(test.TestCase): self.assertEqual(thread_results[0], [2]) self.assertEqual(main_results[0], [3]) + def testBasicUnbatchV1Decorated(self): + """Tests that the batch_function_v1 decorator works.""" + with self.test_session() as sess: + @batch_ops.batch_function_v1(1, 10, 100000) + def computation(in_t): + return in_t + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + def testBasicUnbatchDecorated(self): """Tests that the batch_function decorator works.""" with self.test_session() as sess: + # TODO(apassos): Removing this line causes test flakiness! Ideally should + # be investigated. 
+ default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable + @batch_ops.batch_function(1, 10, 100000) def computation(in_t): return in_t + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + + def testBatchDecoratedWithCapturedInput(self): + """Tests that the batch_function decorator works.""" + with self.test_session() as sess: + captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) + captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) + + @batch_ops.batch_function(1, 10, 100000) + def computation(in_t): + return in_t + captured_inp0 - captured_inp1 + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) result = computation(inp) thread_results = [] @@ -205,6 +258,114 @@ class BatchOpsTest(test.TestCase): self.assertEqual(thread_results[0], [2]) self.assertEqual(main_results[0], [3]) + def testBatchFunctionOp(self): + """Tests that the batch_function op works.""" + with self.test_session() as sess: + + @function.Defun(dtypes.int32) + def computation(in_t): + return in_t + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + result = gen_batch_ops.batch_function( + [inp], + num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, + Tout=[dtypes.int32], + f=computation, + captured_tensors=computation.captured_inputs) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + 
self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + + def testBatchFunctionOpWithCapturedInput(self): + """Tests that batch_function op works with captured input.""" + with self.test_session() as sess: + captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) + captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + + @function.Defun(dtypes.int32) + def computation(inp): + return inp + captured_inp0 - captured_inp1 + + result = gen_batch_ops.batch_function( + num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, # 100ms + allowed_batch_sizes=[3, 10], + batching_queue="", + f=computation, + in_tensors=[inp], + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + + def testBatchFunctionOpWithInputError(self): + """Tests that batch_function op works with error in the inputs.""" + with self.test_session() as sess: + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + + @function.Defun(dtypes.int32, dtypes.int32) + def computation(in0, in1): + return in0 + in1 + + result = gen_batch_ops.batch_function( + [inp], # computation actually expects 2 inputs. 
+ num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, # 100ms + batching_queue="", + f=computation, + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + with self.assertRaisesRegexp(InvalidArgumentError, + ".*2 arguments.*but 1.*"): + sess.run([result], feed_dict={inp: [2]}) + + def testBasicUnbatchDecoratedWithReshape(self): + """Tests that the batch_function decorator works.""" + with self.test_session() as sess: + + @batch_ops.batch_function(1, 10, 100000) + def computation(in_t): + return array_ops.reshape(in_t, [-1]) + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1, 1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [[1]]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [[2]]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + def testUnbatchTimeout(self): """Tests that the unbatch timeout works.""" with self.test_session() as sess: diff --git a/tensorflow/compiler/xla/service/versioned_computation_handle.cc b/tensorflow/contrib/batching/serial_device_batch_scheduler.h similarity index 59% rename from tensorflow/compiler/xla/service/versioned_computation_handle.cc rename to tensorflow/contrib/batching/serial_device_batch_scheduler.h index a693c4695f0e776cf297d0ecd28d6de53bd5c0c6..bf6b7083612018eecf0d1784e60cbbf0c5796fef 100644 --- a/tensorflow/compiler/xla/service/versioned_computation_handle.cc +++ b/tensorflow/contrib/batching/serial_device_batch_scheduler.h @@ -13,20 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" +#ifndef TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ -#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h" -namespace xla { - -string VersionedComputationHandle::ToString() const { - return tensorflow::strings::StrCat(handle.handle(), ":v", version); -} - -std::ostream& operator<<(std::ostream& out, - const VersionedComputationHandle& versioned_handle) { - out << versioned_handle.ToString(); - return out; -} - -} // namespace xla +#endif // TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py index 5770bcdd706723394bb06196d24aeb32b8b8491a..68fa415eeaf1d1ae7c2ecf1be1c300eddbfa4e69 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Monte Carlo integration and helpers. - -See the @{$python/contrib.bayesflow.monte_carlo} guide. 
-""" +"""Monte Carlo integration and helpers.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index 9994c84ebdb930eea0818188225488eb5eca84eb..911d87fa10570382ee5f03edfc1bfd1d116c8360 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -45,6 +45,7 @@ from tensorflow.python.training import training_util _DNN_LEARNING_RATE = 0.001 + def _get_optimizer(optimizer): if callable(optimizer): return optimizer() @@ -73,6 +74,7 @@ def _dnn_tree_combined_model_fn(features, dnn_input_layer_partitioner=None, dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, use_core_versions=False): @@ -108,6 +110,8 @@ def _dnn_tree_combined_model_fn(features, as a feature to the tree. dnn_steps_to_train: Number of steps to train dnn for before switching to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. tree_feature_columns: An iterable containing all the feature columns used by the model's boosted trees. If dnn_input_layer_to_tree is set to True, these features are in addition to dnn_feature_columns. 
@@ -132,8 +136,7 @@ def _dnn_tree_combined_model_fn(features, dnn_parent_scope = "dnn" dnn_partitioner = dnn_input_layer_partitioner or ( partitioned_variables.min_max_variable_partitioner( - max_partitions=config.num_ps_replicas, - min_slice_size=64 << 20)) + max_partitions=config.num_ps_replicas, min_slice_size=64 << 20)) with variable_scope.variable_scope( dnn_parent_scope, @@ -171,8 +174,7 @@ def _dnn_tree_combined_model_fn(features, _add_hidden_layer_summary(net, hidden_layer_scope.name) previous_layer = net with variable_scope.variable_scope( - "logits", - values=(previous_layer,)) as logits_scope: + "logits", values=(previous_layer,)) as logits_scope: dnn_logits = layers.fully_connected( previous_layer, head.logits_dimension, @@ -190,8 +192,7 @@ def _dnn_tree_combined_model_fn(features, optimizer=_get_optimizer(dnn_optimizer), name=dnn_parent_scope, variables=ops.get_collection( - ops.GraphKeys.TRAINABLE_VARIABLES, - scope=dnn_parent_scope), + ops.GraphKeys.TRAINABLE_VARIABLES, scope=dnn_parent_scope), # Empty summaries to prevent optimizers from logging training_loss. 
summaries=[]) @@ -230,7 +231,16 @@ def _dnn_tree_combined_model_fn(features, update_op = state_ops.assign_add(global_step, 1).op return update_op - tree_train_logits = dnn_logits + tree_logits + if predict_with_tree_only: + if mode == model_fn.ModeKeys.TRAIN or mode == model_fn.ModeKeys.PREDICT: + tree_train_logits = tree_logits + else: + tree_train_logits = control_flow_ops.cond( + global_step > dnn_steps_to_train, + lambda: tree_logits, + lambda: dnn_logits) + else: + tree_train_logits = dnn_logits + tree_logits def _no_train_op_fn(loss): """Returns a no-op.""" @@ -288,10 +298,10 @@ def _dnn_tree_combined_model_fn(features, finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() model_fn_ops.training_hooks.extend([ - trainer_hooks.SwitchTrainOp( - dnn_train_op, dnn_steps_to_train, tree_train_op), - trainer_hooks.StopAfterNTrees( - num_trees, attempted_trees, finalized_trees)]) + trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train, + tree_train_op), + trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, finalized_trees) + ]) return model_fn_ops @@ -318,6 +328,7 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): dnn_input_layer_partitioner=None, dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, use_core_versions=False): @@ -360,6 +371,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): as a feature to the tree. dnn_steps_to_train: Number of steps to train dnn for before switching to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. tree_feature_columns: An iterable containing all the feature columns used by the model's boosted trees. If dnn_input_layer_to_tree is set to True, these features are in addition to dnn_feature_columns. 
@@ -377,16 +390,32 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): def _model_fn(features, labels, mode, config): return _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, - tree_learner_config, num_trees, tree_examples_per_layer, config, - dnn_optimizer, dnn_activation_fn, dnn_dropout, - dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, tree_feature_columns, tree_center_bias, - use_core_versions) + features=features, + labels=labels, + mode=mode, + head=head, + dnn_hidden_units=dnn_hidden_units, + dnn_feature_columns=dnn_feature_columns, + tree_learner_config=tree_learner_config, + num_trees=num_trees, + tree_examples_per_layer=tree_examples_per_layer, + config=config, + dnn_optimizer=dnn_optimizer, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + dnn_input_layer_partitioner=dnn_input_layer_partitioner, + dnn_input_layer_to_tree=dnn_input_layer_to_tree, + dnn_steps_to_train=dnn_steps_to_train, + predict_with_tree_only=predict_with_tree_only, + tree_feature_columns=tree_feature_columns, + tree_center_bias=tree_center_bias, + use_core_versions=use_core_versions) super(DNNBoostedTreeCombinedClassifier, self).__init__( - model_fn=_model_fn, model_dir=model_dir, - config=config, feature_engineering_fn=feature_engineering_fn) + model_fn=_model_fn, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) class DNNBoostedTreeCombinedRegressor(estimator.Estimator): @@ -410,6 +439,7 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): dnn_input_layer_partitioner=None, dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, use_core_versions=False): @@ -452,6 +482,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): as a feature to the tree. dnn_steps_to_train: Number of steps to train dnn for before switching to gbdt. 
+ predict_with_tree_only: Whether to use only the tree model output as the + final prediction. tree_feature_columns: An iterable containing all the feature columns used by the model's boosted trees. If dnn_input_layer_to_tree is set to True, these features are in addition to dnn_feature_columns. @@ -474,16 +506,32 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): def _model_fn(features, labels, mode, config): return _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, - tree_learner_config, num_trees, tree_examples_per_layer, config, - dnn_optimizer, dnn_activation_fn, dnn_dropout, - dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, tree_feature_columns, tree_center_bias, - use_core_versions) + features=features, + labels=labels, + mode=mode, + head=head, + dnn_hidden_units=dnn_hidden_units, + dnn_feature_columns=dnn_feature_columns, + tree_learner_config=tree_learner_config, + num_trees=num_trees, + tree_examples_per_layer=tree_examples_per_layer, + config=config, + dnn_optimizer=dnn_optimizer, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + dnn_input_layer_partitioner=dnn_input_layer_partitioner, + dnn_input_layer_to_tree=dnn_input_layer_to_tree, + dnn_steps_to_train=dnn_steps_to_train, + predict_with_tree_only=predict_with_tree_only, + tree_feature_columns=tree_feature_columns, + tree_center_bias=tree_center_bias, + use_core_versions=use_core_versions) super(DNNBoostedTreeCombinedRegressor, self).__init__( - model_fn=_model_fn, model_dir=model_dir, - config=config, feature_engineering_fn=feature_engineering_fn) + model_fn=_model_fn, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) class DNNBoostedTreeCombinedEstimator(estimator.Estimator): @@ -508,6 +556,7 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): dnn_input_layer_partitioner=None, dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + 
predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, use_core_versions=False): @@ -545,6 +594,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): as a feature to the tree. dnn_steps_to_train: Number of steps to train dnn for before switching to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. tree_feature_columns: An iterable containing all the feature columns used by the model's boosted trees. If dnn_input_layer_to_tree is set to True, these features are in addition to dnn_feature_columns. @@ -553,15 +604,32 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. """ + def _model_fn(features, labels, mode, config): return _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, - tree_learner_config, num_trees, tree_examples_per_layer, config, - dnn_optimizer, dnn_activation_fn, dnn_dropout, - dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, tree_feature_columns, tree_center_bias, - use_core_versions) + features=features, + labels=labels, + mode=mode, + head=head, + dnn_hidden_units=dnn_hidden_units, + dnn_feature_columns=dnn_feature_columns, + tree_learner_config=tree_learner_config, + num_trees=num_trees, + tree_examples_per_layer=tree_examples_per_layer, + config=config, + dnn_optimizer=dnn_optimizer, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + dnn_input_layer_partitioner=dnn_input_layer_partitioner, + dnn_input_layer_to_tree=dnn_input_layer_to_tree, + dnn_steps_to_train=dnn_steps_to_train, + predict_with_tree_only=predict_with_tree_only, + tree_feature_columns=tree_feature_columns, + tree_center_bias=tree_center_bias, + use_core_versions=use_core_versions) super(DNNBoostedTreeCombinedEstimator, self).__init__( - model_fn=_model_fn, model_dir=model_dir, - 
config=config, feature_engineering_fn=feature_engineering_fn) + model_fn=_model_fn, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index 89d0d611d2905492cec09e033b8cbc238ec7fac6..9c36c302210185bc390751a0229a61f2f8cd91b8 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -41,7 +41,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): feature_engineering_fn=None, logits_modifier_function=None, center_bias=True, - use_core_libs=False): + use_core_libs=False, + output_leaf_index=False): """Initializes a GradientBoostedDecisionTreeClassifier estimator instance. Args: @@ -66,6 +67,16 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): the bias. use_core_libs: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is + [batch_size, num_trees]. + For example, + result_iter = classifier.predict(...) + for result_dict in result_iter: + # access leaf index list by result_dict["leaf_index"] + # which contains one leaf index per tree + Raises: ValueError: If learner_config is not valid. """ @@ -74,7 +85,9 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): # supports second order derivative. 
def loss_fn(labels, logits, weights=None): result = losses.per_example_maxent_loss( - labels=labels, logits=logits, weights=weights, + labels=labels, + logits=logits, + weights=weights, num_classes=n_classes) return math_ops.reduce_mean(result[0]) else: @@ -102,6 +115,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): 'center_bias': center_bias, 'logits_modifier_function': logits_modifier_function, 'use_core_libs': use_core_libs, + 'output_leaf_index': output_leaf_index, }, model_dir=model_dir, config=config, @@ -124,7 +138,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): feature_engineering_fn=None, logits_modifier_function=None, center_bias=True, - use_core_libs=False): + use_core_libs=False, + output_leaf_index=False): """Initializes a GradientBoostedDecisionTreeRegressor estimator instance. Args: @@ -151,6 +166,13 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): the bias. use_core_libs: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. For example, + result_dict = classifier.predict(...) 
+ for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree """ head = head_lib.regression_head( label_name=label_name, @@ -173,6 +195,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): 'logits_modifier_function': logits_modifier_function, 'center_bias': center_bias, 'use_core_libs': use_core_libs, + 'output_leaf_index': False, }, model_dir=model_dir, config=config, @@ -197,7 +220,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): feature_engineering_fn=None, logits_modifier_function=None, center_bias=True, - use_core_libs=False): + use_core_libs=False, + output_leaf_index=False): """Initializes a GradientBoostedDecisionTreeEstimator estimator instance. Args: @@ -220,6 +244,13 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): the bias. use_core_libs: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. For example, + result_dict = classifier.predict(...) 
+ for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree """ super(GradientBoostedDecisionTreeEstimator, self).__init__( model_fn=model.model_builder, @@ -233,6 +264,7 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): 'logits_modifier_function': logits_modifier_function, 'center_bias': center_bias, 'use_core_libs': use_core_libs, + 'output_leaf_index': False, }, model_dir=model_dir, config=config, diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index 0d58317bd59331cfcde0e12aeb3a3a03fc45d89b..75ef1b050028b6462b255827c06e836e5c481844 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -68,6 +68,28 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): classifier.evaluate(input_fn=_eval_input_fn, steps=1) classifier.export(self._export_dir_base) + def testThatLeafIndexIsInPredictions(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[contrib_feature_column.real_valued_column("x")], + output_leaf_index=True) + + classifier.fit(input_fn=_train_input_fn, steps=15) + result_iter = classifier.predict(input_fn=_eval_input_fn) + for prediction_dict in result_iter: + self.assertTrue("leaf_index" in prediction_dict) + self.assertTrue("logits" in prediction_dict) + def testFitAndEvaluateDontThrowExceptionWithCoreForEstimator(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index 15ab6d814522ab1dee58dcd71246354fc4d8a483..1ee891198939e53fc5913104b2c2e65dc977823f 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -63,6 +63,8 @@ def model_builder(features, labels, mode, params, config): num_trees = params["num_trees"] use_core_libs = params["use_core_libs"] logits_modifier_function = params["logits_modifier_function"] + output_leaf_index = params["output_leaf_index"] + if features is None: raise ValueError("At least one feature must be specified.") @@ -96,7 +98,8 @@ def model_builder(features, labels, mode, params, config): feature_columns=feature_columns, logits_dimension=head.logits_dimension, features=training_features, - use_core_columns=use_core_libs) + use_core_columns=use_core_libs, + output_leaf_index=output_leaf_index) with ops.name_scope("gbdt", "gbdt_optimizer"): predictions_dict = gbdt_model.predict(mode) logits = predictions_dict["predictions"] @@ -127,6 +130,9 @@ def model_builder(features, labels, mode, params, config): labels=labels, train_op_fn=_train_op_fn, logits=logits) + if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict: + model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[ + gbdt_batch.LEAF_INDEX] if num_trees: if center_bias: num_trees += 1 diff --git a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc index b3fe38614e05801b223f0c96f7a70ce7e432a70b..9493c1a1394040db3b744f1b382b20bd5bd1988d 100644 --- a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc @@ -59,6 +59,7 @@ const char* kApplyDropoutAttributeName = "apply_dropout"; const char* kApplyAveragingAttributeName = "apply_averaging"; const char* kDropoutInfoOutputTensorName 
= "drop_out_tree_indices_weights"; const char* kPredictionsTensorName = "predictions"; +const char* kLeafIndexTensorName = "leaf_index"; void CalculateTreesToInclude( const boosted_trees::trees::DecisionTreeEnsembleConfig& config, @@ -170,15 +171,22 @@ class GradientTreesPredictionOp : public OpKernel { core::ScopedUnref unref_me(ensemble_resource); if (use_locking_) { tf_shared_lock l(*ensemble_resource->get_mutex()); - DoCompute(context, ensemble_resource); + DoCompute(context, ensemble_resource, + /*return_output_leaf_index=*/false); } else { - DoCompute(context, ensemble_resource); + DoCompute(context, ensemble_resource, + /*return_output_leaf_index=*/false); } } - private: - void DoCompute(OpKernelContext* context, - DecisionTreeEnsembleResource* ensemble_resource) { + protected: + // return_output_leaf_index is a boolean variable indicating whether to output + // leaf index in prediction. Though this class invokes only with this param + // value as false, the subclass GradientTreesPredictionVerboseOp will invoke + // with the true value. + virtual void DoCompute(OpKernelContext* context, + DecisionTreeEnsembleResource* ensemble_resource, + const bool return_output_leaf_index) { // Read dense float features list; OpInputList dense_float_features_list; OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures( @@ -267,6 +275,14 @@ class GradientTreesPredictionOp : public OpKernel { &output_predictions_t)); auto output_predictions = output_predictions_t->matrix(); + // Allocate output leaf index matrix. + Tensor* output_leaf_index_t = nullptr; + if (return_output_leaf_index) { + OP_REQUIRES_OK(context, context->allocate_output( + kLeafIndexTensorName, + {batch_size, ensemble_resource->num_trees()}, + &output_leaf_index_t)); + } // Run predictor. 
thread::ThreadPool* const worker_threads = context->device()->tensorflow_cpu_worker_threads()->workers; @@ -288,11 +304,13 @@ class GradientTreesPredictionOp : public OpKernel { i, weight * (num_ensembles - i + start_averaging) / num_ensembles); } MultipleAdditiveTrees::Predict(adjusted, trees_to_include, batch_features, - worker_threads, output_predictions); + worker_threads, output_predictions, + output_leaf_index_t); } else { MultipleAdditiveTrees::Predict( ensemble_resource->decision_tree_ensemble(), trees_to_include, - batch_features, worker_threads, output_predictions); + batch_features, worker_threads, output_predictions, + output_leaf_index_t); } // Output dropped trees and original weights. @@ -302,7 +320,6 @@ class GradientTreesPredictionOp : public OpKernel { {2, static_cast(dropped_trees.size())}, &output_dropout_info_t)); auto output_dropout_info = output_dropout_info_t->matrix(); - for (int32 i = 0; i < dropped_trees.size(); ++i) { output_dropout_info(0, i) = dropped_trees[i]; output_dropout_info(1, i) = original_weights[i]; @@ -326,6 +343,27 @@ class GradientTreesPredictionOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("GradientTreesPrediction").Device(DEVICE_CPU), GradientTreesPredictionOp); +// GradientTreesPredictionVerboseOp is derived from GradientTreesPredictionOp +// and have an additional output of tensor of rank 2 containing leaf ids for +// each tree where an instance ended up with. 
+class GradientTreesPredictionVerboseOp : public GradientTreesPredictionOp { + public: + explicit GradientTreesPredictionVerboseOp(OpKernelConstruction* const context) + : GradientTreesPredictionOp(context) {} + + protected: + void DoCompute(OpKernelContext* context, + DecisionTreeEnsembleResource* ensemble_resource, + bool return_output_leaf_index) override { + GradientTreesPredictionOp::DoCompute(context, ensemble_resource, + /*return_output_leaf_index=*/true); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("GradientTreesPredictionVerbose").Device(DEVICE_CPU), + GradientTreesPredictionVerboseOp); + class GradientTreesPartitionExamplesOp : public OpKernel { public: explicit GradientTreesPartitionExamplesOp(OpKernelConstruction* const context) diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 44a8ffaf4b2f5a9c11b3abc46ce55a18c80ad318..401bec84a20a0fefcddbfa1039a117e65f853633 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -43,47 +43,60 @@ namespace { const int32 DUMMY_FEATURE_DIMENSION = -1; } // namespace -class BaseBuildSplitOp : public OpKernel { +class SplitBuilderState { public: - explicit BaseBuildSplitOp(OpKernelConstruction* const context) - : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("feature_column_group_id", - &feature_column_group_id_)); + explicit SplitBuilderState(OpKernelContext* const context) { + const Tensor* l1_regularization_t; OP_REQUIRES_OK(context, - context->GetAttr("l1_regularization", &l1_regularization_)); + context->input("l1_regularization", &l1_regularization_t)); + const Tensor* l2_regularization_t; OP_REQUIRES_OK(context, - context->GetAttr("l2_regularization", &l2_regularization_)); - OP_REQUIRES_OK(context, context->GetAttr("tree_complexity_regularization", - &tree_complexity_regularization_)); + 
context->input("l2_regularization", &l2_regularization_t)); + const Tensor* tree_complexity_regularization_t; + OP_REQUIRES_OK(context, context->input("tree_complexity_regularization", + &tree_complexity_regularization_t)); + const Tensor* min_node_weight_t; OP_REQUIRES_OK(context, - context->GetAttr("min_node_weight", &min_node_weight_)); + context->input("min_node_weight", &min_node_weight_t)); - int strategy; - OP_REQUIRES_OK(context, context->GetAttr("multiclass_strategy", &strategy)); + const Tensor* feature_column_group_id_t; + OP_REQUIRES_OK(context, context->input("feature_column_group_id", + &feature_column_group_id_t)); + + const Tensor* multiclass_strategy_t; + OP_REQUIRES_OK( + context, context->input("multiclass_strategy", &multiclass_strategy_t)); + int strategy = multiclass_strategy_t->scalar<int32>()(); OP_REQUIRES( context, boosted_trees::learner::LearnerConfig_MultiClassStrategy_IsValid( strategy), errors::InvalidArgument("Wrong multiclass strategy passed.")); - multiclass_strategy_ = LearnerConfig_MultiClassStrategy(strategy); - } - NodeStats ComputeNodeStats(const GradientStats& grad_stats) { - return NodeStats(l1_regularization_, l2_regularization_, min_node_weight_, - multiclass_strategy_, grad_stats); - } + multiclass_strategy_ = LearnerConfig_MultiClassStrategy(strategy); - void ReadClassId(OpKernelContext* const context, int32* class_id) { const Tensor* class_id_t; OP_REQUIRES_OK(context, context->input("class_id", &class_id_t)); OP_REQUIRES(context, TensorShapeUtils::IsScalar(class_id_t->shape()), errors::InvalidArgument("class_id must be a scalar.")); - *class_id = class_id_t->scalar<int32>()(); + class_id_ = class_id_t->scalar<int32>()(); + + l1_regularization_ = l1_regularization_t->scalar<float>()(); + l2_regularization_ = l2_regularization_t->scalar<float>()(); + tree_complexity_regularization_ = + tree_complexity_regularization_t->scalar<float>()(); + min_node_weight_ = min_node_weight_t->scalar<float>()(); + feature_column_group_id_ = feature_column_group_id_t->scalar<int32>()(); } -
void FillLeaf(const int class_id, const NodeStats& best_node_stats, + NodeStats ComputeNodeStats(const GradientStats& grad_stats) { + return NodeStats(l1_regularization_, l2_regularization_, min_node_weight_, + multiclass_strategy_, grad_stats); + } + + void FillLeaf(const NodeStats& best_node_stats, boosted_trees::trees::Leaf* leaf) const { - if (class_id == -1) { + if (class_id_ == -1) { // This would be the case either for TREE_PER_CLASS with only 2 classes, // or for other multiclass strategies. for (float f : best_node_stats.weight_contribution) { @@ -93,25 +106,31 @@ class BaseBuildSplitOp : public OpKernel { CHECK(best_node_stats.weight_contribution.size() == 1) << "Weight contribution size = " << best_node_stats.weight_contribution.size(); - leaf->mutable_sparse_vector()->add_index(class_id); + leaf->mutable_sparse_vector()->add_index(class_id_); leaf->mutable_sparse_vector()->add_value( best_node_stats.weight_contribution[0]); } } - protected: + int32 feature_column_group_id() { return feature_column_group_id_; } + float tree_complexity_regularization() { + return tree_complexity_regularization_; + } + + private: LearnerConfig_MultiClassStrategy multiclass_strategy_; - int32 feature_column_group_id_; float l1_regularization_; float l2_regularization_; - float min_node_weight_; float tree_complexity_regularization_; + float min_node_weight_; + int32 class_id_; + int32 feature_column_group_id_; }; -class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { +class BuildDenseInequalitySplitsOp : public OpKernel { public: explicit BuildDenseInequalitySplitsOp(OpKernelConstruction* const context) - : BaseBuildSplitOp(context) {} + : OpKernel(context) {} void Compute(OpKernelContext* const context) override { const Tensor* num_minibatches_t; @@ -139,9 +158,6 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); - int class_id; - ReadClassId(context, 
&class_id); - // Find the number of unique partitions before we allocate the output. std::vector<int32> partition_boundaries; partition_boundaries.push_back(0); @@ -185,6 +201,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes<string>::Vec output_splits = output_splits_t->vec<string>(); + SplitBuilderState state(context); for (int root_idx = 0; root_idx < num_elements; ++root_idx) { float best_gain = std::numeric_limits<float>::lowest(); int start_index = partition_boundaries[root_idx]; @@ -196,7 +213,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats(*gradients_t, *hessians_t, bucket_idx); } root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); int32 best_bucket_idx = 0; NodeStats best_right_node_stats(0); NodeStats best_left_node_stats(0); @@ -206,10 +223,10 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats g(*gradients_t, *hessians_t, bucket_idx); g *= normalizer_ratio; left_gradient_stats += g; - NodeStats left_stats = ComputeNodeStats(left_gradient_stats); + NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats); GradientStats right_gradient_stats = root_gradient_stats - left_gradient_stats; - NodeStats right_stats = ComputeNodeStats(right_gradient_stats); + NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats); if (left_stats.gain + right_stats.gain > best_gain) { best_gain = left_stats.gain + right_stats.gain; best_left_node_stats = left_stats; @@ -220,18 +237,18 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { SplitInfo split_info; auto* dense_split = split_info.mutable_split_node()->mutable_dense_float_binary_split(); - dense_split->set_feature_column(feature_column_group_id_); + dense_split->set_feature_column(state.feature_column_group_id()); dense_split->set_threshold(
bucket_boundaries(bucket_ids(best_bucket_idx, 0))); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - FillLeaf(class_id, best_left_node_stats, left_child); - FillLeaf(class_id, best_right_node_stats, right_child); + state.FillLeaf(best_left_node_stats, left_child); + state.FillLeaf(best_right_node_stats, right_child); split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = - best_gain - root_stats.gain - tree_complexity_regularization_; + best_gain - root_stats.gain - state.tree_complexity_regularization(); output_partition_ids(root_idx) = partition_ids(start_index); } } @@ -239,13 +256,10 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { REGISTER_KERNEL_BUILDER(Name("BuildDenseInequalitySplits").Device(DEVICE_CPU), BuildDenseInequalitySplitsOp); -class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { +class BuildSparseInequalitySplitsOp : public OpKernel { public: explicit BuildSparseInequalitySplitsOp(OpKernelConstruction* const context) - : BaseBuildSplitOp(context) { - OP_REQUIRES_OK(context, - context->GetAttr("bias_feature_id", &bias_feature_id_)); - } + : OpKernel(context) {} void Compute(OpKernelContext* const context) override { const Tensor* num_minibatches_t; @@ -275,8 +289,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); - int class_id; - ReadClassId(context, &class_id); + const Tensor* bias_feature_id_t; + OP_REQUIRES_OK(context, + context->input("bias_feature_id", &bias_feature_id_t)); + int64 bias_feature_id = bias_feature_id_t->scalar()(); // For each partition (tree node), store starting index for each dimension. 
PartitionAndDimensionBoundaries partition_boundaries; @@ -354,6 +370,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + SplitBuilderState state(context); // For each tree node that needs to be split. for (int root_idx = 0; root_idx < num_elements; ++root_idx) { const auto& dimension_boundaries = @@ -372,7 +389,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { OP_REQUIRES( context, - bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id_, + bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id, errors::InvalidArgument("Bias feature ID missing.")); // Dimension for bias feature is always 0 @@ -388,7 +405,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats root_gradient_stats(*gradients_t, *hessians_t, bias_start_index); root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); // Iterate through dimensions. for (int j = 0; j < dimension_boundaries.size() - 1; ++j) { @@ -408,7 +425,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { << bucket_ids_and_dimensions(start_index, 1) << " and for " << bucket_ids_and_dimensions(end_index - 1, 0) << " " << bucket_ids_and_dimensions(end_index - 1, 1); - if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id_) { + if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id) { // 0-dimension case which has a first bucket for catch all feature. 
CHECK(bucket_ids_and_dimensions(start_index, 1) == 0) << "Dimension of bias feature should be 0"; @@ -422,6 +439,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats(*gradients_t, *hessians_t, bucket_idx); } present_gradient_stats *= normalizer_ratio; + GradientStats not_present = + root_gradient_stats - present_gradient_stats; + // If there was (almost) no sparsity, fix the default direction to LEFT. + bool fixed_default_direction = not_present.IsAlmostZero(); GradientStats left_gradient_stats; for (int64 element_idx = start_index; element_idx < end_index; @@ -441,11 +462,12 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { // backward pass gradients. GradientStats right_gradient_stats = present_gradient_stats - left_gradient_stats; + { - NodeStats left_stats_default_left = - ComputeNodeStats(root_gradient_stats - right_gradient_stats); + NodeStats left_stats_default_left = state.ComputeNodeStats( + root_gradient_stats - right_gradient_stats); NodeStats right_stats_default_left = - ComputeNodeStats(right_gradient_stats); + state.ComputeNodeStats(right_gradient_stats); if (left_stats_default_left.gain + right_stats_default_left.gain > best_gain) { best_gain = @@ -457,11 +479,13 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { best_dimension_idx = dimension_id; } } - { + // Consider calculating the default direction only when there were + // enough missing examples. 
+ if (!fixed_default_direction) { NodeStats left_stats_default_right = - ComputeNodeStats(left_gradient_stats); - NodeStats right_stats_default_right = - ComputeNodeStats(root_gradient_stats - left_gradient_stats); + state.ComputeNodeStats(left_gradient_stats); + NodeStats right_stats_default_right = state.ComputeNodeStats( + root_gradient_stats - left_gradient_stats); if (left_stats_default_right.gain + right_stats_default_right.gain > best_gain) { best_gain = left_stats_default_right.gain + @@ -487,7 +511,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { ->mutable_sparse_float_binary_split_default_left() ->mutable_split(); } - dense_split->set_feature_column(feature_column_group_id_); + dense_split->set_feature_column(state.feature_column_group_id()); // Set the feature index for the best feature column. const int64 best_dimension_id = bucket_ids_and_dimensions(best_element_idx, 1); @@ -498,11 +522,11 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - FillLeaf(class_id, best_left_node_stats, left_child); - FillLeaf(class_id, best_right_node_stats, right_child); + state.FillLeaf(best_left_node_stats, left_child); + state.FillLeaf(best_right_node_stats, right_child); split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = - best_gain - root_stats.gain - tree_complexity_regularization_; + best_gain - root_stats.gain - state.tree_complexity_regularization(); output_partition_ids(root_idx) = partition_ids(bias_start_index); } } @@ -519,19 +543,14 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { // For each partition, store start indices of feature column dimensions. 
typedef std::vector> PartitionAndDimensionBoundaries; - - int64 bias_feature_id_; }; REGISTER_KERNEL_BUILDER(Name("BuildSparseInequalitySplits").Device(DEVICE_CPU), BuildSparseInequalitySplitsOp); -class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { +class BuildCategoricalEqualitySplitsOp : public OpKernel { public: explicit BuildCategoricalEqualitySplitsOp(OpKernelConstruction* const context) - : BaseBuildSplitOp(context) { - OP_REQUIRES_OK(context, - context->GetAttr("bias_feature_id", &bias_feature_id_)); - } + : OpKernel(context) {} void Compute(OpKernelContext* const context) override { const Tensor* num_minibatches_t; @@ -554,8 +573,10 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); - int class_id; - ReadClassId(context, &class_id); + const Tensor* bias_feature_id_t; + OP_REQUIRES_OK(context, + context->input("bias_feature_id", &bias_feature_id_t)); + int64 bias_feature_id = bias_feature_id_t->scalar()(); // Find the number of unique partitions before we allocate the output. std::vector partition_boundaries; @@ -598,16 +619,17 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + SplitBuilderState state(context); for (int root_idx = 0; root_idx < num_elements; ++root_idx) { float best_gain = std::numeric_limits::lowest(); int start_index = partition_boundaries[non_empty_partitions[root_idx]]; int end_index = partition_boundaries[non_empty_partitions[root_idx] + 1]; // First feature ID in each partition should be the bias feature. 
- OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id_, + OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id, errors::InvalidArgument("Bias feature ID missing.")); GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index); root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); int32 best_feature_idx = 0; NodeStats best_right_node_stats(0); NodeStats best_left_node_stats(0); @@ -618,8 +640,8 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { left_gradient_stats *= normalizer_ratio; GradientStats right_gradient_stats = root_gradient_stats - left_gradient_stats; - NodeStats left_stats = ComputeNodeStats(left_gradient_stats); - NodeStats right_stats = ComputeNodeStats(right_gradient_stats); + NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats); + NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats); if (left_stats.gain + right_stats.gain > best_gain) { best_gain = left_stats.gain + right_stats.gain; best_left_node_stats = left_stats; @@ -630,21 +652,18 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { SplitInfo split_info; auto* equality_split = split_info.mutable_split_node() ->mutable_categorical_id_binary_split(); - equality_split->set_feature_column(feature_column_group_id_); + equality_split->set_feature_column(state.feature_column_group_id()); equality_split->set_feature_id(feature_ids(best_feature_idx, 0)); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - FillLeaf(class_id, best_left_node_stats, left_child); - FillLeaf(class_id, best_right_node_stats, right_child); + state.FillLeaf(best_left_node_stats, left_child); + state.FillLeaf(best_right_node_stats, right_child); split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = - best_gain - 
root_stats.gain - tree_complexity_regularization_; + best_gain - root_stats.gain - state.tree_complexity_regularization(); output_partition_ids(root_idx) = partition_ids(start_index); } } - - private: - int64 bias_feature_id_; }; REGISTER_KERNEL_BUILDER( diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index f06b73c00d0bebb2717a79b7894e2addf914daba..409a2d8f46c331c13aec10542c4967d50575e94a 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -64,6 +64,8 @@ from __future__ import print_function import re from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler +from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops +from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops from tensorflow.contrib.boosted_trees.python.ops import quantile_ops from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops @@ -72,9 +74,11 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops + _BIAS_FEATURE_ID = -1 # Pattern to remove all non alpha numeric from a string. 
_PATTERN = re.compile(r"[\W_]+") @@ -130,11 +134,14 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler): gradient_shape, hessian_shape, name="StatsAccumulator/{}".format(self._name)) - self._quantile_accumulator = quantile_ops.QuantileAccumulator( - init_stamp_token, - epsilon=epsilon, - num_quantiles=num_quantiles, - name="QuantileAccumulator/{}".format(self._name)) + # Allocate both stats accumulator and quantile accumulator on the same + # device so that we can build splits with fewer RPCs. + with ops.colocate_with(self._stats_accumulator.resource()): + self._quantile_accumulator = quantile_ops.QuantileAccumulator( + init_stamp_token, + epsilon=epsilon, + num_quantiles=num_quantiles, + name="QuantileAccumulator/{}".format(self._name)) class DenseSplitHandler(InequalitySplitHandler): @@ -236,45 +243,74 @@ class DenseSplitHandler(InequalitySplitHandler): def make_splits(self, stamp_token, next_stamp_token, class_id): """Create the best split using the accumulated stats and flush the state.""" - # Get the bucket boundaries - are_splits_ready, buckets = ( - self._quantile_accumulator.get_buckets(stamp_token)) - # After we receive the boundaries from previous iteration we can flush - # the quantile accumulator. - with ops.control_dependencies([buckets]): - flush_quantiles = self._quantile_accumulator.flush( - stamp_token=stamp_token, next_stamp_token=next_stamp_token) - - # Get the aggregated gradients and hessians per - # pair. - # In order to distribute the computation on all the PSs we use the PS that - # had the stats accumulator on. - with ops.device(None): - with ops.device(self._stats_accumulator.resource().device): - num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( - self._stats_accumulator.flush(stamp_token, next_stamp_token)) - - # Put quantile and stats accumulator flushing in the dependency path. 
- are_splits_ready = control_flow_ops.with_dependencies( - [flush_quantiles, partition_ids], are_splits_ready) - - partition_ids, gains, split_infos = ( - split_handler_ops.build_dense_inequality_splits( - num_minibatches=num_minibatches, - bucket_boundaries=buckets, - partition_ids=partition_ids, - bucket_ids=bucket_ids, - gradients=gradients, - hessians=hessians, - class_id=class_id, - feature_column_group_id=self._feature_column_group_id, - l1_regularization=self._l1_regularization, - l2_regularization=self._l2_regularization, - tree_complexity_regularization=self. - _tree_complexity_regularization, - min_node_weight=self._min_node_weight, - multiclass_strategy=self._multiclass_strategy)) - return (are_splits_ready, partition_ids, gains, split_infos) + if (self._gradient_shape == tensor_shape.scalar() and + self._hessian_shape == tensor_shape.scalar()): + handler = make_dense_split_scalar + else: + handler = make_dense_split_tensor + + are_splits_ready, partition_ids, gains, split_infos = ( + handler(self._quantile_accumulator.resource(), + self._stats_accumulator.resource(), stamp_token, + next_stamp_token, self._multiclass_strategy, class_id, + self._feature_column_group_id, self._l1_regularization, + self._l2_regularization, self._tree_complexity_regularization, + self._min_node_weight)) + return are_splits_ready, partition_ids, gains, split_infos + + +def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle, + stamp_token, next_stamp_token, multiclass_strategy, + class_id, feature_column_id, l1_regularization, + l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional): + """Function that builds splits for a dense feature column.""" + # Get the bucket boundaries + are_splits_ready, buckets = ( + gen_quantile_ops.quantile_accumulator_get_buckets( + quantile_accumulator_handles=[quantile_accumulator_handle], + stamp_token=stamp_token)) + # quantile_accumulator_get_buckets returns a list of results per handle 
that + # we pass to it. In this case we're getting results just for one resource. + are_splits_ready = are_splits_ready[0] + buckets = buckets[0] + + # After we receive the boundaries from previous iteration we can flush + # the quantile accumulator. + with ops.control_dependencies([buckets]): + flush_quantiles = gen_quantile_ops.quantile_accumulator_flush( + quantile_accumulator_handle=quantile_accumulator_handle, + stamp_token=stamp_token, + next_stamp_token=next_stamp_token) + + if is_multi_dimentional: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_tensor_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + else: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_scalar_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + + # Put quantile and stats accumulator flushing in the dependency path. + with ops.control_dependencies([flush_quantiles, partition_ids]): + are_splits_ready = array_ops.identity(are_splits_ready) + partition_ids, gains, split_infos = ( + split_handler_ops.build_dense_inequality_splits( + num_minibatches=num_minibatches, + bucket_boundaries=buckets, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + class_id=class_id, + feature_column_group_id=feature_column_id, + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, + min_node_weight=min_node_weight, + multiclass_strategy=multiclass_strategy)) + return are_splits_ready, partition_ids, gains, split_infos class SparseSplitHandler(InequalitySplitHandler): @@ -327,9 +363,6 @@ class SparseSplitHandler(InequalitySplitHandler): multiclass_strategy=multiclass_strategy, init_stamp_token=init_stamp_token, name=name) - # Register sparse_make_stats_update function as an Op to the graph. 
- g = ops.get_default_graph() - sparse_make_stats_update.add_to_graph(g) self._sparse_float_column = sparse_float_column def scheduled_reads(self): @@ -361,8 +394,8 @@ class SparseSplitHandler(InequalitySplitHandler): are_buckets_ready, buckets = scheduled_reads[0] with ops.name_scope(self._name, "SparseSplitHandler"): (quantile_indices, quantile_values, quantile_shapes, quantile_weights, - example_partition_ids, - feature_ids, gradients, hessians) = sparse_make_stats_update( + example_partition_ids, feature_ids, gradients, + hessians) = sparse_make_stats_update( is_active, are_buckets_ready, self._sparse_float_column.indices, self._sparse_float_column.values, self._sparse_float_column.dense_shape, buckets, @@ -379,42 +412,115 @@ class SparseSplitHandler(InequalitySplitHandler): def make_splits(self, stamp_token, next_stamp_token, class_id): """Create the best split using the accumulated stats and flush the state.""" - # Get the bucket boundaries - are_splits_ready, buckets = ( - self._quantile_accumulator.get_buckets(stamp_token)) - - # After we receive the boundaries from previous iteration we can flush - # the quantile accumulator. - with ops.control_dependencies([buckets]): - flush_quantiles = self._quantile_accumulator.flush( - stamp_token=stamp_token, next_stamp_token=next_stamp_token) - - with ops.device(None): - with ops.device(self._stats_accumulator.resource().device): - num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( - self._stats_accumulator.flush(stamp_token, next_stamp_token)) - - # Put quantile and stats accumulator flushing in the dependency path. 
- are_splits_ready = control_flow_ops.with_dependencies( - [flush_quantiles, partition_ids], are_splits_ready) - partition_ids, gains, split_infos = ( - split_handler_ops.build_sparse_inequality_splits( - num_minibatches=num_minibatches, - bucket_boundaries=buckets, - partition_ids=partition_ids, - bucket_ids=bucket_ids, - gradients=gradients, - hessians=hessians, - class_id=class_id, - feature_column_group_id=self._feature_column_group_id, - l1_regularization=self._l1_regularization, - l2_regularization=self._l2_regularization, - tree_complexity_regularization=self. - _tree_complexity_regularization, - min_node_weight=self._min_node_weight, - bias_feature_id=_BIAS_FEATURE_ID, - multiclass_strategy=self._multiclass_strategy)) - return (are_splits_ready, partition_ids, gains, split_infos) + if (self._gradient_shape == tensor_shape.scalar() and + self._hessian_shape == tensor_shape.scalar()): + handler = make_sparse_split_scalar + else: + handler = make_sparse_split_tensor + + are_splits_ready, partition_ids, gains, split_infos = ( + handler(self._quantile_accumulator.resource(), + self._stats_accumulator.resource(), stamp_token, + next_stamp_token, self._multiclass_strategy, class_id, + self._feature_column_group_id, self._l1_regularization, + self._l2_regularization, self._tree_complexity_regularization, + self._min_node_weight)) + return are_splits_ready, partition_ids, gains, split_infos + + +def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle, + stamp_token, next_stamp_token, multiclass_strategy, + class_id, feature_column_id, l1_regularization, + l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional): + """Function that builds splits for a sparse feature column.""" + # Get the bucket boundaries + are_splits_ready, buckets = ( + gen_quantile_ops.quantile_accumulator_get_buckets( + quantile_accumulator_handles=[quantile_accumulator_handle], + stamp_token=stamp_token)) + # 
quantile_accumulator_get_buckets returns a list of results per handle that + # we pass to it. In this case we're getting results just for one resource. + are_splits_ready = are_splits_ready[0] + buckets = buckets[0] + + # After we receive the boundaries from previous iteration we can flush + # the quantile accumulator. + with ops.control_dependencies([buckets]): + flush_quantiles = gen_quantile_ops.quantile_accumulator_flush( + quantile_accumulator_handle=quantile_accumulator_handle, + stamp_token=stamp_token, + next_stamp_token=next_stamp_token) + + if is_multi_dimentional: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_tensor_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + else: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_scalar_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + + # Put quantile and stats accumulator flushing in the dependency path. 
+ with ops.control_dependencies([flush_quantiles, partition_ids]): + are_splits_ready = array_ops.identity(are_splits_ready) + partition_ids, gains, split_infos = ( + split_handler_ops.build_sparse_inequality_splits( + num_minibatches=num_minibatches, + bucket_boundaries=buckets, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + class_id=class_id, + feature_column_group_id=feature_column_id, + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, + min_node_weight=min_node_weight, + bias_feature_id=_BIAS_FEATURE_ID, + multiclass_strategy=multiclass_strategy)) + return are_splits_ready, partition_ids, gains, split_infos + + +def _specialize_make_split(func, is_multi_dimentional): + """Builds a specialized version of the function.""" + + @function.Defun( + dtypes.resource, + dtypes.resource, + dtypes.int64, + dtypes.int64, + dtypes.int32, + dtypes.int32, + dtypes.int32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + noinline=True) + def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token, + next_stamp_token, multiclass_strategy, class_id, feature_column_id, + l1_regularization, l2_regularization, tree_complexity_regularization, + min_node_weight): + """Function that builds splits for a sparse feature column.""" + return func( + quantile_accumulator_handle, stats_accumulator_handle, stamp_token, + next_stamp_token, multiclass_strategy, class_id, feature_column_id, + l1_regularization, l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional) + + return f + +make_dense_split_scalar = _specialize_make_split(_make_dense_split, + is_multi_dimentional=False) +make_dense_split_tensor = _specialize_make_split(_make_dense_split, + is_multi_dimentional=True) + +make_sparse_split_scalar = _specialize_make_split(_make_sparse_split, + is_multi_dimentional=False) 
+make_sparse_split_tensor = _specialize_make_split(_make_sparse_split, + is_multi_dimentional=True) @function.Defun( @@ -540,8 +646,9 @@ def sparse_make_stats_update( empty_float = constant_op.constant([], dtype=dtypes.float32) handler_not_active = (constant_op.constant( - [], dtype=dtypes.int64, shape=[0, 2]), empty_float, constant_op.constant( - [0, 1], dtype=dtypes.int64), empty_float) + [], dtype=dtypes.int64, shape=[0, 2]), empty_float, + constant_op.constant([0, 1], dtype=dtypes.int64), + empty_float) handler_active = (sparse_column_indices, sparse_column_values, sparse_column_shape, weights) quantile_indices, quantile_values, quantile_shape, quantile_weights = ( diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index 54d03018d9e266beabbbabd78ebbb80cfe689c04..2f2c2302113bf59d6a065d5005c934dc76c2148d 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.boosted_trees.lib.learner.batch import ordinal_split_handler from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.proto import split_info_pb2 @@ -65,9 +67,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): hessian_shape = tensor_shape.scalar() split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, - l2_regularization=1, - tree_complexity_regularization=0, - min_node_weight=0, + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., epsilon=0.001, num_quantiles=10, feature_column_group_id=0, @@ -92,7 +94,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, 
is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -105,7 +109,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -199,10 +203,10 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): hessian_shape = tensor_shape.TensorShape([2, 2]) split_handler = ordinal_split_handler.DenseSplitHandler( - l1_regularization=0, - l2_regularization=1, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0., + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., epsilon=0.001, num_quantiles=3, feature_column_group_id=0, @@ -227,7 +231,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -240,7 +246,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + 
split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -285,10 +291,10 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): hessian_shape = tensor_shape.TensorShape([2]) split_handler = ordinal_split_handler.DenseSplitHandler( - l1_regularization=0, - l2_regularization=1, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0., + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., epsilon=0.001, num_quantiles=3, feature_column_group_id=0, @@ -313,7 +319,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -326,7 +333,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -369,9 +376,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, - l2_regularization=1, - tree_complexity_regularization=0, - min_node_weight=0, + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., epsilon=0.001, num_quantiles=10, feature_column_group_id=0, @@ -396,7 +403,8 @@ class 
DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, False])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -409,7 +417,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([False, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -443,9 +451,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, - l2_regularization=1, + l2_regularization=1., tree_complexity_regularization=0.5, - min_node_weight=0, + min_node_weight=0., epsilon=0.001, num_quantiles=10, feature_column_group_id=0, @@ -470,7 +478,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -483,7 +492,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) 
are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -576,7 +585,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, - l2_regularization=1, + l2_regularization=1., tree_complexity_regularization=0.5, min_node_weight=1.5, epsilon=0.001, @@ -603,7 +612,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -616,7 +626,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -685,10 +695,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -713,8 +723,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] - 
+ are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -727,7 +737,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -811,10 +821,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -839,7 +849,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -853,7 +864,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -905,10 +916,10 
@@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -933,7 +944,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -947,7 +959,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -996,10 +1008,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -1024,7 +1036,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, False])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + 
are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -1038,7 +1051,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([False, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -1065,10 +1078,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -1096,7 +1109,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -1110,7 +1124,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -1138,10 
+1152,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -1166,7 +1180,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -1180,7 +1195,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc index 43b00d4c6dc2e0066810012292874314215c41be..c9223afeab233497bce9f680bd44bd10ccfc6491 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc @@ -26,7 +26,8 @@ void MultipleAdditiveTrees::Predict( const std::vector& trees_to_include, const boosted_trees::utils::BatchFeatures& features, tensorflow::thread::ThreadPool* const worker_threads, - tensorflow::TTypes::Matrix output_predictions) { + 
tensorflow::TTypes<float>::Matrix output_predictions, + Tensor* const output_leaf_index) { // Zero out predictions as the model is additive. output_predictions.setZero(); @@ -38,8 +39,13 @@ void MultipleAdditiveTrees::Predict( // Lambda for doing a block of work. auto update_predictions = [&config, &features, &trees_to_include, - &output_predictions](int64 start, int64 end) { + &output_predictions, + &output_leaf_index](int64 start, int64 end) { auto examples_iterable = features.examples_iterable(start, end); + Tensor dummy_tensor(DT_INT32, TensorShape({1, 1})); + tensorflow::TTypes<int>::Matrix output_leaf_index_mat = + output_leaf_index != nullptr ? output_leaf_index->matrix<int>() + : dummy_tensor.matrix<int>(); for (const auto& example : examples_iterable) { for (const int32 tree_idx : trees_to_include) { const boosted_trees::trees::DecisionTreeConfig& tree = @@ -47,6 +53,10 @@ void MultipleAdditiveTrees::Predict( const float tree_weight = config.tree_weights(tree_idx); const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example); QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString(); + // Checks if output leaf tree index is required. + if (output_leaf_index != nullptr) { + output_leaf_index_mat(example.example_idx, tree_idx) = leaf_idx; + } const auto& leaf_node = tree.nodes(leaf_idx); QCHECK(leaf_node.has_leaf()) << "Invalid leaf node: " << leaf_node.DebugString(); diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h index cc3dc226cdbc88fc7010ada1e7f0e6c0a3913c5f..940531c4ba4bcac19fa980deb091e55b48e0693b 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h @@ -33,12 +33,17 @@ class MultipleAdditiveTrees { public: // Predict runs tree ensemble on the given batch and updates // output predictions accordingly, for the given list of trees.
+ // output_leaf_indices is a pointer to a 2 dimensional tensor. If it is not + // nullptr, this method fills output_leaf_indices with a per-tree leaf id + // where each of the instances from 'features' ended up in. Its shape is num + // examples X num of trees. static void Predict( const boosted_trees::trees::DecisionTreeEnsembleConfig& config, const std::vector<int32>& trees_to_include, const boosted_trees::utils::BatchFeatures& features, tensorflow::thread::ThreadPool* const worker_threads, - tensorflow::TTypes<float>::Matrix output_predictions); + tensorflow::TTypes<float>::Matrix output_predictions, + Tensor* const output_leaf_index); }; } // namespace models diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc index 4ca18bedb1054ef64c6d4b25bbad04842bab1a6a..462a9ac86fe51d07cfb958d9be49bef84811a52e 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc @@ -62,7 +62,8 @@ TEST_F(MultipleAdditiveTreesTest, Empty) { tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test", kNumThreadsSingleThreaded); MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, + /*output_leaf_index=*/nullptr); EXPECT_EQ(0, output_matrix(0, 0)); EXPECT_EQ(0, output_matrix(1, 0)); } @@ -99,17 +100,38 @@ TEST_F(MultipleAdditiveTreesTest, SingleClass) { // Normal case. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, - batch_features_, &threads, output_matrix); + batch_features_, &threads, output_matrix, + /*output_leaf_index=*/nullptr); EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). } + // Normal case with leaf node.
+ { + // Initialize output leaf index tensor, since leaf index is positive in this + // case, initialize with the value of -1. Since there are 2 examples and + // there are 2 trees, initialize leaf output index by 2 * 2. + Tensor output_leaf_index_tensor(DT_INT32, TensorShape({2, 2})); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, + batch_features_, &threads, output_matrix, + &output_leaf_index_tensor); + EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). + EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). + EXPECT_FLOAT_EQ(0, output_leaf_index_tensor.matrix<int>()( + 0, 0)); // 1st leaf for the first example + EXPECT_FLOAT_EQ(0, output_leaf_index_tensor.matrix<int>()( + 1, 0)); // 1st leaf for the second example + EXPECT_FLOAT_EQ(2, output_leaf_index_tensor.matrix<int>()( + 0, 1)); // 2nd leaf for the first example + EXPECT_FLOAT_EQ(1, output_leaf_index_tensor.matrix<int>()( + 1, 1)); // 2nd leaf for the second example + } // Weighted case { DecisionTreeEnsembleConfig weighted = tree_ensemble_config; weighted.set_tree_weights(0, 6.0); weighted.set_tree_weights(1, 3.2); MultipleAdditiveTrees::Predict(weighted, {0, 1}, batch_features_, &threads, - output_matrix); + output_matrix, nullptr); // -0.4 (bias) + 0.2 (leaf 2). EXPECT_FLOAT_EQ(-0.4f * 6 + 0.2 * 3.2, output_matrix(0, 0)); // -0.4 (bias) + 0.9 (leaf 1). @@ -118,21 +140,21 @@ TEST_F(MultipleAdditiveTreesTest, SingleClass) { // Drop first tree. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {1}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 0)); // 0.2 (leaf 2). EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 1). } // Drop second tree. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias).
EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias). } // Drop all trees. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.0, output_matrix(1, 0)); } @@ -172,7 +194,8 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { // Normal case. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, - batch_features_, &threads, output_matrix); + batch_features_, &threads, output_matrix, + nullptr); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias) EXPECT_FLOAT_EQ(-0.5f, output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2) EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1) @@ -184,7 +207,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { weighted.set_tree_weights(0, 6.0); weighted.set_tree_weights(1, 3.2); MultipleAdditiveTrees::Predict(weighted, {0, 1}, batch_features_, &threads, - output_matrix); + output_matrix, nullptr); // bias EXPECT_FLOAT_EQ(-0.4f * 6, output_matrix(0, 0)); // bias + leaf 2 @@ -197,7 +220,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { // Dropout first tree. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {1}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 1)); // 0.2 (leaf 2) EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 2) @@ -206,7 +229,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { // Dropout second tree. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias) EXPECT_FLOAT_EQ(-0.7f, output_matrix(0, 1)); // -0.7 (bias) EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias) @@ -215,7 +238,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { // Drop both trees. 
{ MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 1)); EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 0)); @@ -258,7 +281,8 @@ TEST_F(MultipleAdditiveTreesTest, DenseLeaves) { // Normal case. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, - batch_features_, &threads, output_matrix); + batch_features_, &threads, output_matrix, + nullptr); EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (tree1) + 0.2 (leaf 2) EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 1)); // -0.7 (tree1) + 0.3 (leaf 2) EXPECT_FLOAT_EQ(3.4f, output_matrix(0, 2)); // 3.0 -(tree1) + 0.4 (leaf 2) diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h index 8ad97fedc923ac50bcaad86e0ba2c2e46df6821b..c120dd8a6c156ec9eb7ba0b6c552f5138bd21a16 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h @@ -295,7 +295,7 @@ WeightedQuantilesStream::GetQuantileSpecs( if (eps <= std::numeric_limits::epsilon()) { // Exact quantile computation at the expense of RAM. 
max_level = 1; - block_size = std::max(max_elements, 2LL); + block_size = std::max(max_elements, int64{2}); } else { // The bottom-most level will become full at most // (max_elements / block_size) times, the level above will become full @@ -315,7 +315,7 @@ WeightedQuantilesStream::GetQuantileSpecs( block_size = static_cast(ceil(max_level / eps)) + 1; } } - return std::make_tuple(max_level, std::max(block_size, 2LL)); + return std::make_tuple(max_level, std::max(block_size, int64{2})); } } // namespace quantiles diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h index 7576856dc3a6d0b6681ee9745c875cf46d1e2960..a7e7bfc13cadcea4d29d33e0dbd955bdad6ffcb9 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h @@ -195,7 +195,7 @@ class WeightedQuantilesSummary { // designed to be cache-friendly. void Compress(int64 size_hint, double min_eps = 0) { // No-op if we're already within the size requirement. - size_hint = std::max(size_hint, 2LL); + size_hint = std::max(size_hint, int64{2}); if (entries_.size() <= size_hint) { return; } @@ -267,7 +267,7 @@ class WeightedQuantilesSummary { if (entries_.empty()) { return output; } - num_quantiles = std::max(num_quantiles, 2LL); + num_quantiles = std::max(num_quantiles, int64{2}); output.reserve(num_quantiles + 1); // Make successive rank queries to get boundaries. 
diff --git a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc index d66f645f62aba84261337eb37d6e3204930f8f15..6491d58794332e9417951753532e018aafb652b1 100644 --- a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc @@ -40,6 +40,24 @@ static Status ApplyGradientTreesPredictionShapeFn(InferenceContext* c) { return Status::OK(); } +static Status ApplyGradientTreesPredictionVerboseShapeFn(InferenceContext* c) { + string learner_config_str; + c->GetAttr("learner_config", &learner_config_str).IgnoreError(); + LearnerConfig learner_config; + ParseProtoUnlimited(&learner_config, learner_config_str); + + bool reduce_dim; + c->GetAttr("reduce_dim", &reduce_dim).IgnoreError(); + // Sets the shape of the output as a matrix. + c->set_output(0, {c->Matrix(InferenceContext::kUnknownDim, + reduce_dim ? learner_config.num_classes() - 1 + : learner_config.num_classes())}); + c->set_output(1, {c->UnknownShape()}); + c->set_output(2, {c->Matrix(InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim)}); + return Status::OK(); +} + REGISTER_OP("GradientTreesPrediction") .Attr("learner_config: string") .Attr("num_dense_float_features: int >= 0") @@ -90,6 +108,58 @@ drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices and original weights of those trees during prediction. 
)doc"); +REGISTER_OP("GradientTreesPredictionVerbose") + .Attr("learner_config: string") + .Attr("num_dense_float_features: int >= 0") + .Attr("num_sparse_float_features: int >= 0") + .Attr("num_sparse_int_features: int >= 0") + .Attr("use_locking: bool = false") + .Attr("apply_dropout: bool") + .Attr("apply_averaging: bool") + .Attr("center_bias: bool") + .Attr("reduce_dim: bool") + .Input("tree_ensemble_handle: resource") + .Input("seed: int64") + .Input("dense_float_features: num_dense_float_features * float") + .Input("sparse_float_feature_indices: num_sparse_float_features * int64") + .Input("sparse_float_feature_values: num_sparse_float_features * float") + .Input("sparse_float_feature_shapes: num_sparse_float_features * int64") + .Input("sparse_int_feature_indices: num_sparse_int_features * int64") + .Input("sparse_int_feature_values: num_sparse_int_features * int64") + .Input("sparse_int_feature_shapes: num_sparse_int_features * int64") + .Output("predictions: float") + .Output("drop_out_tree_indices_weights: float") + .Output("leaf_index: int32") + .SetShapeFn(ApplyGradientTreesPredictionVerboseShapeFn) + .Doc(R"doc( +Runs multiple additive regression forests predictors on input instances +and computes the final prediction for each class, and outputs a matrix of +leaf ids per each tree in an ensemble. + +learner_config: Config for the learner of type LearnerConfig proto. Prediction +ops for now uses only LearningRateDropoutDrivenConfig config from the learner. +num_dense_float_features: Number of dense float features. +num_sparse_float_features: Number of sparse float features. +num_sparse_int_features: Number of sparse int features. +use_locking: Whether to use locking. +seed: random seed to be used for dropout. +reduce_dim: whether to reduce the dimension (legacy impl) or not. +apply_dropout: whether to apply dropout during prediction. +apply_averaging: whether averaging of tree ensembles should take place. 
If set +to true, will be based on AveragingConfig from learner_config. +tree_ensemble_handle: The handle to the tree ensemble. +dense_float_features: Rank 2 Tensors containing dense float feature values. +sparse_float_feature_indices: Rank 2 Tensors containing sparse float indices. +sparse_float_feature_values: Rank 1 Tensors containing sparse float values. +sparse_float_feature_shapes: Rank 1 Tensors containing sparse float shapes. +sparse_int_feature_indices: Rank 2 Tensors containing sparse int indices. +sparse_int_feature_values: Rank 1 Tensors containing sparse int values. +sparse_int_feature_shapes: Rank 1 Tensors containing sparse int shapes. +predictions: Rank 2 Tensor containing predictions per example per class. +drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices and original weights of those trees during prediction. +leaf_index: Tensor of Rank 2 containing, for each tree, the id of the leaf where an instance ended up. +)doc"); + REGISTER_OP("GradientTreesPartitionExamples") .Attr("num_dense_float_features: int >= 0") .Attr("num_sparse_float_features: int >= 0") .Attr("num_sparse_int_features: int >= 0") diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc index 5d0ebbf73ce1272b51a475f67984db3a181b7130..ca5c7f3d8c78a543c63fbfa9f7eb7c3d348f11b8 100644 --- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc @@ -23,12 +23,6 @@ using shape_inference::InferenceContext; using shape_inference::ShapeHandle; REGISTER_OP("BuildDenseInequalitySplits") - .Attr("feature_column_group_id: int") - .Attr("l1_regularization: float") - .Attr("l2_regularization: float") - .Attr("tree_complexity_regularization: float") - .Attr("min_node_weight: float") - .Attr("multiclass_strategy: int") .Input("num_minibatches: int64") .Input("partition_ids: int32") .Input("bucket_ids: int64") @@ -36,6 +30,12 @@ REGISTER_OP("BuildDenseInequalitySplits") .Input("hessians: float32") .Input("bucket_boundaries: float32")
.Input("class_id: int32") + .Input("feature_column_group_id: int32") + .Input("l1_regularization: float") + .Input("l2_regularization: float") + .Input("tree_complexity_regularization: float") + .Input("min_node_weight: float") + .Input("multiclass_strategy: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -73,6 +73,17 @@ bucket_ids: A rank 2 tensor of buckets IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization. +class_id: A scalar, the class id for which we're building the splits. +feature_column_group_id: A scalar, the index of the feature we are splitting on. +l1_regularization: A scalar, which specifies the l1 regularization term. +l2_regularization: A scalar, which specifies the l2 regularization term. +tree_complexity_regularization: A scalar, which specifies the tree complexity + regularization term. +min_node_weight: A scalar, minimum sum of example hessian needed in a child. + If a split results in a leaf node with a smaller value, the split will not + be considered. +multiclass_strategy: A scalar, specifying the multiclass handling strategy. + See LearnerConfig.MultiClassStrategy for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits.
@@ -81,13 +92,6 @@ split_infos: A rank 1 tensor of serialized protos which contains the )doc"); REGISTER_OP("BuildSparseInequalitySplits") - .Attr("feature_column_group_id: int") - .Attr("bias_feature_id: int") - .Attr("l1_regularization: float") - .Attr("l2_regularization: float") - .Attr("tree_complexity_regularization: float") - .Attr("min_node_weight: float") - .Attr("multiclass_strategy: int") .Input("num_minibatches: int64") .Input("partition_ids: int32") .Input("bucket_ids: int64") @@ -95,6 +99,13 @@ REGISTER_OP("BuildSparseInequalitySplits") .Input("hessians: float32") .Input("bucket_boundaries: float32") .Input("class_id: int32") + .Input("feature_column_group_id: int32") + .Input("bias_feature_id: int64") + .Input("l1_regularization: float") + .Input("l2_regularization: float") + .Input("tree_complexity_regularization: float") + .Input("min_node_weight: float") + .Input("multiclass_strategy: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -133,6 +144,17 @@ bucket_ids: A rank 2 tensor of buckets IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization. +class_id: A scalar, the class id for which we're building the splits. +feature_column_group_id: A scalar, the index of the feature we are splitting on. +l1_regularization: A scalar, which specifies the l1 regularization term. +l2_regularization: A scalar, which specifies the l2 regularization term. +tree_complexity_regularization: A scalar, which specifies the tree complexity + regularization term. +min_node_weight: A scalar, minimum sum of example hessian needed in a child. + If a split results in a leaf node with a smaller value, the split will not + be considered. +multiclass_strategy: A scalar, specifying the multiclass handling strategy. + See LearnerConfig.MultiClassStrategy for valid values.
output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. @@ -141,19 +163,19 @@ split_infos: A rank 1 tensor of serialized protos which contains the )doc"); REGISTER_OP("BuildCategoricalEqualitySplits") - .Attr("feature_column_group_id: int") - .Attr("bias_feature_id: int") - .Attr("l1_regularization: float") - .Attr("l2_regularization: float") - .Attr("tree_complexity_regularization: float") - .Attr("min_node_weight: float") - .Attr("multiclass_strategy: int") .Input("num_minibatches: int64") .Input("partition_ids: int32") .Input("feature_ids: int64") .Input("gradients: float32") .Input("hessians: float32") .Input("class_id: int32") + .Input("feature_column_group_id: int32") + .Input("bias_feature_id: int64") + .Input("l1_regularization: float") + .Input("l2_regularization: float") + .Input("tree_complexity_regularization: float") + .Input("min_node_weight: float") + .Input("multiclass_strategy: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -188,6 +210,17 @@ partition_ids: A rank 1 tensor of partition IDs. feature_ids: A rank 2 tensor of feature IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. +class_id: A scalar, the class id for which we're building the splits. +feature_column_group_id: A scalar, the index of the feature we are splitting on. +l1_regularization: A scalar, which specifies the l1 regularization term. +l2_regularization: A scalar, which specifies the l2 regularization term. +tree_complexity_regularization: A scalar, which specifies the tree complexity + regularization term. +min_node_weight: A scalar, minimum sum of example hessian needed in a child. + If a split results in a leaf node with a smaller value, the split will not + be considered. +multiclass_strategy: A scalar, specifying the multiclass handling strategy.
+ See LearnerConfig.MultiClassStrategy for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. @@ -196,4 +229,3 @@ split_infos: A rank 1 tensor of serialized protos which contains the )doc"); } // namespace tensorflow - // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py index 28834ef55bf8e1f32cc8f2380a4be3bf3824d8e1..5cd37ec67ec3bdefb6ea19049a7a12249162d45a 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import random + from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.proto import split_info_pb2 from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops @@ -399,6 +401,65 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): self.assertAllClose(0.6, split_node.split.threshold) + def testMakeSparseSplitDefaultDirectionIsStable(self): + """Tests default direction is stable when no sparsity.""" + random.seed(1123) + for _ in range(50): + with self.test_session() as sess: + grad = random.random() + hessian = random.random() + # The data looks like the following (divide by the num of steps 2). + # Gradients | Partition | bucket ID | + # (grad, hessian) | 0 | -1 | + # And then 100 buckets of + # (grad/100, hessian/100), so there is no sparsity. + n_buckets = 100 + + # 1 for the overall sum, and 100 buckets. + partition_ids = array_ops.constant( + [0] * (n_buckets + 1), dtype=dtypes.int32) + # We have only 1 dimension in our sparse feature column. 
+ + bucket_ids = [-1] + [n for n in range(100)] + bucket_ids = array_ops.constant(bucket_ids, dtype=dtypes.int64) + dimension_ids = array_ops.constant( + [0] * (n_buckets + 1), dtype=dtypes.int64) + bucket_ids = array_ops.stack([bucket_ids, dimension_ids], axis=1) + + gradients = [grad] + [grad / n_buckets] * n_buckets + gradients = array_ops.constant(gradients) + hessians = [hessian] + [hessian / n_buckets] * n_buckets + hessians = array_ops.constant(hessians) + + boundaries = [x * 1 for x in range(n_buckets + 1)] + bucket_boundaries = array_ops.constant(boundaries, dtype=dtypes.float32) + + partitions, gains, splits = ( + split_handler_ops.build_sparse_inequality_splits( + num_minibatches=2, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + bucket_boundaries=bucket_boundaries, + l1_regularization=0, + l2_regularization=2, + tree_complexity_regularization=0, + min_node_weight=0, + feature_column_group_id=0, + bias_feature_id=-1, + class_id=-1, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + partitions, gains, splits = (sess.run([partitions, gains, splits])) + self.assertAllEqual([0], partitions) + self.assertEqual(1, len(splits)) + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + self.assertTrue( + split_info.split_node.HasField( + 'sparse_float_binary_split_default_left')) + def testMakeMulticlassSparseSplit(self): """Tests split handler op.""" with self.test_session() as sess: diff --git a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py index 7a5f329b7ab3216972180ccbb4c85f2537175422..843420968ac6a6716fdf6b4967146e131139f67c 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py +++ b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py @@ -20,6 +20,8 @@ from __future__ import print_function import abc import collections +from 
tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -60,6 +62,7 @@ def _move_tensors(tensors, device): """Moves a list of tensors to a device by concatenating/splitting them.""" # Reset the device setting to avoid weird interactions with device merging # logic. + zero = constant_op.constant(0, dtype=dtypes.int32) with ops.device(None): if all(tensor.shape == tensor_shape.scalar() for tensor in tensors): with ops.device(tensors[0].device): @@ -68,12 +71,11 @@ def _move_tensors(tensors, device): return array_ops.unstack(values) else: with ops.device(tensors[0].device): - sizes = array_ops.stack( - [array_ops.shape(tensor)[0] for tensor in tensors]) - values = array_ops.concat(tensors, axis=0) + sizes = array_ops.stack(array_ops.shape_n(tensors))[:, 0] + values = array_ops.concat(tensors, axis=zero) with ops.device(device): sizes = array_ops.unstack(sizes) - return list(array_ops.split(values, sizes, axis=0)) + return list(array_ops.split(values, sizes, axis=zero)) def _scheduled_stamp_resource_op_runner(batch, stamp): diff --git a/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py b/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py index 58f0d36b0f78eeed6abcec1c4fa696f4ccffa615..7f6e55ae5888fc4ef50e34690d61c3ed303e971a 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py @@ -21,4 +21,5 @@ from __future__ import print_function from tensorflow.contrib.boosted_trees.python.ops import boosted_trees_ops_loader from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_partition_examples from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_prediction +from 
tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_prediction_verbose # pylint: enable=unused-import diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 50cc00afdcc77fedc9bf8c94a9a6fcf2a28ebde9..19b6b3296db394b07f57a25dbde187eb9195af38 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -201,3 +201,6 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): stamp_token=stamp_token, next_stamp_token=next_stamp_token) return result + + def resource(self): + return self._quantile_accumulator_handle diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index e53d86ec612f299c800753d67ceee79acb5db497..47698d45c81478f2b694aaadc603f742c44d5351 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -58,6 +58,7 @@ NUM_LAYERS_ATTEMPTED = "num_layers" NUM_TREES_ATTEMPTED = "num_trees" NUM_USED_HANDLERS = "num_used_handlers" USED_HANDLERS_MASK = "used_handlers_mask" +LEAF_INDEX = "leaf_index" _FEATURE_NAME_TEMPLATE = "%s_%d" @@ -71,18 +72,24 @@ def _get_column_by_index(tensor, indices): return array_ops.reshape(array_ops.gather(p_flat, i_flat), [shape[0], -1]) -def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats, - used_handlers): +def _make_predictions_dict(stamp, + logits, + partition_ids, + ensemble_stats, + used_handlers, + leaf_index=None): """Returns predictions for the given logits and n_classes. Args: stamp: The ensemble stamp. - logits: A rank 2 `Tensor` with shape [batch_size, n_classes - 1]. - that contains predictions when no dropout was applied. 
+ logits: A rank 2 `Tensor` with shape [batch_size, n_classes - 1] that + contains predictions when no dropout was applied. partition_ids: A rank 1 `Tensor` with shape [batch_size]. ensemble_stats: A TreeEnsembleStatsOp result tuple. used_handlers: A TreeEnsembleUsedHandlerOp result tuple of an int and a - boolean mask.. + boolean mask. + leaf_index: A rank 2 `Tensor` with shape [batch_size, number of trees] that + contains the leaf id for each example prediction. Returns: A dict of predictions. @@ -95,6 +102,8 @@ def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats, result[NUM_TREES_ATTEMPTED] = ensemble_stats.attempted_trees result[NUM_USED_HANDLERS] = used_handlers.num_used_handlers result[USED_HANDLERS_MASK] = used_handlers.used_handlers_mask + if leaf_index is not None: + result[LEAF_INDEX] = leaf_index return result @@ -180,8 +189,7 @@ def extract_features(features, feature_columns, use_core_columns): elif isinstance(fc, feature_column_lib._EmbeddingColumn): # pylint: enable=protected-access transformed_features[fc.name] = fc_core.input_layer( - features, [fc], - weight_collections=[scope]) + features, [fc], weight_collections=[scope]) else: result = feature_column_ops.transform_features(features, [fc]) if len(result) > 1: @@ -269,7 +277,8 @@ class GradientBoostedDecisionTreeModel(object): features, logits_dimension, feature_columns=None, - use_core_columns=False): + use_core_columns=False, + output_leaf_index=False): """Construct a new GradientBoostedDecisionTreeModel function. Args: @@ -277,13 +286,15 @@ class GradientBoostedDecisionTreeModel(object): num_ps_replicas: Number of parameter server replicas, can be 0. ensemble_handle: A handle to the ensemble variable. center_bias: Whether to center the bias before growing trees. - examples_per_layer: Number of examples to accumulate before growing - a tree layer. It can also be a function that computes the number of - examples based on the depth of the layer that's being built.
+ examples_per_layer: Number of examples to accumulate before growing a tree + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. learner_config: A learner config. features: `dict` of `Tensor` objects. logits_dimension: An int, the dimension of logits. feature_columns: A list of feature columns. + output_leaf_index: A boolean variable indicating whether to output leaf + index into predictions dictionary. Raises: ValueError: if inputs are not valid. @@ -334,10 +345,12 @@ class GradientBoostedDecisionTreeModel(object): self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() self._attempted_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + initial_value=array_ops.zeros([], dtypes.int64), + trainable=False, name="attempted_trees") self._finalized_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + initial_value=array_ops.zeros([], dtypes.int64), + trainable=False, name="finalized_trees") if not features: raise ValueError("Features dictionary must be specified.") @@ -354,9 +367,11 @@ class GradientBoostedDecisionTreeModel(object): self._sparse_int_indices = sparse_int_indices self._sparse_int_values = sparse_int_values self._sparse_int_shapes = sparse_int_shapes - self._reduce_dim = (self._learner_config.multi_class_strategy == - learner_pb2.LearnerConfig.TREE_PER_CLASS and - learner_config.num_classes == 2) + self._reduce_dim = ( + self._learner_config.multi_class_strategy == + learner_pb2.LearnerConfig.TREE_PER_CLASS and + learner_config.num_classes == 2) + self._output_leaf_index = output_leaf_index def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode): """Runs prediction and returns a dictionary of the prediction results. 
@@ -374,8 +389,8 @@ class GradientBoostedDecisionTreeModel(object): ensemble_stats = training_ops.tree_ensemble_stats(ensemble_handle, ensemble_stamp) num_handlers = ( - len(self._dense_floats) + len(self._sparse_float_shapes) + - len(self._sparse_int_shapes)) + len(self._dense_floats) + len(self._sparse_float_shapes) + len( + self._sparse_int_shapes)) # Used during feature selection. used_handlers = model_ops.tree_ensemble_used_handlers( ensemble_handle, ensemble_stamp, num_all_handlers=num_handlers) @@ -386,22 +401,44 @@ class GradientBoostedDecisionTreeModel(object): # Make sure ensemble stats run. This will check that the ensemble has # the right stamp. with ops.control_dependencies(ensemble_stats): - predictions, _ = prediction_ops.gradient_trees_prediction( - ensemble_handle, - seed, - self._dense_floats, - self._sparse_float_indices, - self._sparse_float_values, - self._sparse_float_shapes, - self._sparse_int_indices, - self._sparse_int_values, - self._sparse_int_shapes, - learner_config=self._learner_config_serialized, - apply_dropout=apply_dropout, - apply_averaging=mode != learn.ModeKeys.TRAIN, - use_locking=True, - center_bias=self._center_bias, - reduce_dim=self._reduce_dim) + leaf_index = None + # Only used in infer (predict), not used in train and eval. 
+ if self._output_leaf_index and mode == learn.ModeKeys.INFER: + predictions, _, leaf_index = ( + prediction_ops).gradient_trees_prediction_verbose( + ensemble_handle, + seed, + self._dense_floats, + self._sparse_float_indices, + self._sparse_float_values, + self._sparse_float_shapes, + self._sparse_int_indices, + self._sparse_int_values, + self._sparse_int_shapes, + learner_config=self._learner_config_serialized, + apply_dropout=apply_dropout, + apply_averaging=mode != learn.ModeKeys.TRAIN, + use_locking=True, + center_bias=self._center_bias, + reduce_dim=self._reduce_dim) + else: + leaf_index = None + predictions, _ = prediction_ops.gradient_trees_prediction( + ensemble_handle, + seed, + self._dense_floats, + self._sparse_float_indices, + self._sparse_float_values, + self._sparse_float_shapes, + self._sparse_int_indices, + self._sparse_int_values, + self._sparse_int_shapes, + learner_config=self._learner_config_serialized, + apply_dropout=apply_dropout, + apply_averaging=mode != learn.ModeKeys.TRAIN, + use_locking=True, + center_bias=self._center_bias, + reduce_dim=self._reduce_dim) partition_ids = prediction_ops.gradient_trees_partition_examples( ensemble_handle, self._dense_floats, @@ -414,7 +451,7 @@ class GradientBoostedDecisionTreeModel(object): use_locking=True) return _make_predictions_dict(ensemble_stamp, predictions, partition_ids, - ensemble_stats, used_handlers) + ensemble_stats, used_handlers, leaf_index) def predict(self, mode): """Returns predictions given the features and mode. @@ -432,8 +469,9 @@ class GradientBoostedDecisionTreeModel(object): # Use the current ensemble to predict on the current batch of input. # For faster prediction we check if the inputs are on the same device # as the model. If not, we create a copy of the model on the worker. 
- input_deps = (self._dense_floats + self._sparse_float_indices + - self._sparse_int_indices) + input_deps = ( + self._dense_floats + self._sparse_float_indices + + self._sparse_int_indices) if not input_deps: raise ValueError("No input tensors for prediction.") @@ -457,8 +495,8 @@ class GradientBoostedDecisionTreeModel(object): # Determine whether the local ensemble is stale and update it if needed. def _refresh_local_ensemble_fn(): - # Serialize the model from parameter server after reading all inputs. - with ops.control_dependencies(input_deps): + # Serialize the model from parameter server after reading the inputs. + with ops.control_dependencies([input_deps[0]]): (ensemble_stamp, serialized_model) = ( model_ops.tree_ensemble_serialize(self._ensemble_handle)) @@ -500,8 +538,9 @@ class GradientBoostedDecisionTreeModel(object): ValueError: if inputs are not valid. """ # Get the worker device from input dependencies. - input_deps = (self._dense_floats + self._sparse_float_indices + - self._sparse_int_indices) + input_deps = ( + self._dense_floats + self._sparse_float_indices + + self._sparse_int_indices) worker_device = input_deps[0].device # Get tensors relevant for training and form the loss. @@ -517,7 +556,7 @@ class GradientBoostedDecisionTreeModel(object): aggregation_method=None)[0] strategy = self._learner_config.multi_class_strategy - class_id = -1 + class_id = constant_op.constant(-1, dtype=dtypes.int32) # Handle different multiclass strategies. if strategy == learner_pb2.LearnerConfig.TREE_PER_CLASS: # We build one vs rest trees. @@ -571,31 +610,39 @@ class GradientBoostedDecisionTreeModel(object): # Get the weights for each example for quantiles calculation, weights = self._get_weights(hessian_shape, squeezed_hessians) - regularization_config = self._learner_config.regularization - min_node_weight = self._learner_config.constraints.min_node_weight # Create all handlers ensuring resources are evenly allocated across PS. 
fc_name_idx = 0 handlers = [] init_stamp_token = constant_op.constant(0, dtype=dtypes.int64) + l1_regularization = constant_op.constant( + self._learner_config.regularization.l1, dtypes.float32) + l2_regularization = constant_op.constant( + self._learner_config.regularization.l2, dtypes.float32) + tree_complexity_regularization = constant_op.constant( + self._learner_config.regularization.tree_complexity, dtypes.float32) + min_node_weight = constant_op.constant( + self._learner_config.constraints.min_node_weight, dtypes.float32) + epsilon = 0.01 + num_quantiles = 100 + strategy_tensor = constant_op.constant(strategy) with ops.device(self._get_replica_device_setter(worker_device)): # Create handlers for dense float columns for dense_float_column_idx in range(len(self._dense_floats)): fc_name = self._fc_names[fc_name_idx] handlers.append( ordinal_split_handler.DenseSplitHandler( - l1_regularization=regularization_config.l1, - l2_regularization=regularization_config.l2, - tree_complexity_regularization=( - regularization_config.tree_complexity), + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, feature_column_group_id=dense_float_column_idx, - epsilon=0.01, - num_quantiles=100, + epsilon=epsilon, + num_quantiles=num_quantiles, dense_float_column=self._dense_floats[dense_float_column_idx], name=fc_name, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=strategy, + multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token)) fc_name_idx += 1 @@ -604,14 +651,13 @@ class GradientBoostedDecisionTreeModel(object): fc_name = self._fc_names[fc_name_idx] handlers.append( ordinal_split_handler.SparseSplitHandler( - l1_regularization=regularization_config.l1, - l2_regularization=regularization_config.l2, - tree_complexity_regularization=( - regularization_config.tree_complexity), + 
l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, feature_column_group_id=sparse_float_column_idx, - epsilon=0.01, - num_quantiles=100, + epsilon=epsilon, + num_quantiles=num_quantiles, sparse_float_column=sparse_tensor.SparseTensor( self._sparse_float_indices[sparse_float_column_idx], self._sparse_float_values[sparse_float_column_idx], @@ -619,7 +665,7 @@ class GradientBoostedDecisionTreeModel(object): name=fc_name, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=strategy, + multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token)) fc_name_idx += 1 @@ -628,10 +674,9 @@ class GradientBoostedDecisionTreeModel(object): fc_name = self._fc_names[fc_name_idx] handlers.append( categorical_split_handler.EqualitySplitHandler( - l1_regularization=regularization_config.l1, - l2_regularization=regularization_config.l2, - tree_complexity_regularization=( - regularization_config.tree_complexity), + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, feature_column_group_id=sparse_int_column_idx, sparse_int_column=sparse_tensor.SparseTensor( @@ -641,7 +686,7 @@ class GradientBoostedDecisionTreeModel(object): name=fc_name, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=strategy, + multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token)) fc_name_idx += 1 @@ -694,11 +739,11 @@ class GradientBoostedDecisionTreeModel(object): name="continue_centering", trainable=False) stats_update_ops.append( - control_flow_ops.cond(continue_centering, - self._make_update_bias_stats_fn( - ensemble_stamp, predictions, gradients, - bias_stats_accumulator), - control_flow_ops.no_op)) + control_flow_ops.cond( + continue_centering, + 
self._make_update_bias_stats_fn(ensemble_stamp, predictions, + gradients, bias_stats_accumulator), + control_flow_ops.no_op)) # Update handler stats. handler_reads = collections.OrderedDict() @@ -720,8 +765,8 @@ class GradientBoostedDecisionTreeModel(object): shape=[len(handlers)], seed=[seed + 1, 1]) active_handlers = array_ops.stack( [active_handlers_current_layer, active_handlers_next_layer], axis=1) - active_handlers = (active_handlers < - self._learner_config.feature_fraction_per_level) + active_handlers = ( + active_handlers < self._learner_config.feature_fraction_per_level) elif subsampling_type == "feature_fraction_per_tree": seed = predictions_dict[NUM_TREES_ATTEMPTED] active_handlers_current_layer = stateless.stateless_random_uniform( @@ -729,9 +774,12 @@ class GradientBoostedDecisionTreeModel(object): active_handlers_current_layer = ( active_handlers_current_layer < self._learner_config.feature_fraction_per_tree) - active_handlers = array_ops.stack([ - active_handlers_current_layer, - array_ops.ones([len(handlers)], dtype=dtypes.bool)], axis=1) + active_handlers = array_ops.stack( + [ + active_handlers_current_layer, + array_ops.ones([len(handlers)], dtype=dtypes.bool) + ], + axis=1) else: active_handlers = array_ops.ones([len(handlers), 2], dtype=dtypes.bool) @@ -760,6 +808,7 @@ class GradientBoostedDecisionTreeModel(object): empty_hessians = constant_op.constant( [], dtype=dtypes.float32, shape=empty_hess_shape) + active_handlers = array_ops.unstack(active_handlers, axis=0) for handler_idx in range(len(handlers)): handler = handlers[handler_idx] is_active = active_handlers[handler_idx] @@ -901,7 +950,6 @@ class GradientBoostedDecisionTreeModel(object): "DecisionTreeEnsembleResourceHandleOp", "StatsAccumulatorScalarResourceHandleOp", "StatsAccumulatorTensorResourceHandleOp", - "QuantileStreamResourceHandleOp", ] ps_strategy = _OpRoundRobinStrategy(ps_ops, ps_tasks) return device_setter.replica_device_setter( @@ -971,7 +1019,7 @@ class 
GradientBoostedDecisionTreeModel(object): # This is a workaround for the slowness of graph building in tf.cond. # See (b/36554864). split_sizes = array_ops.reshape( - array_ops.shape_n(partition_ids_list), [-1]) + array_ops.shape_n(partition_ids_list), [len(partition_ids_list)]) partition_ids = array_ops.concat(partition_ids_list, axis=0) gains = array_ops.concat(gains_list, axis=0) split_infos = array_ops.concat(split_info_list, axis=0) @@ -1036,8 +1084,11 @@ class GradientBoostedDecisionTreeModel(object): # Update ensemble. update_ops = [are_all_splits_ready] - update_model = control_flow_ops.cond(continue_centering, _center_bias_fn, - _grow_ensemble_fn) + if self._center_bias: + update_model = control_flow_ops.cond(continue_centering, + _center_bias_fn, _grow_ensemble_fn) + else: + update_model = _grow_ensemble_fn() update_ops.append(update_model) # Update ensemble stats. diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index f9c22283b7f5136777bfa60a12c94974adfbd245..e3d4397fadcbaf148f7f6cfaca13e850639786cf 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -19,19 +19,15 @@ from __future__ import division from __future__ import print_function from google.protobuf import text_format - from tensorflow.contrib import layers from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.proto import tree_config_pb2 from tensorflow.contrib.boosted_trees.python.ops import model_ops from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch from tensorflow.contrib.boosted_trees.python.utils import losses - -from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.contrib.layers.python.layers import feature_column as 
feature_column_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn - - +from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util @@ -97,8 +93,8 @@ class GbdtTest(test_util.TensorFlowTestCase): array_ops.zeros([2], dtypes.int64)) features["sparse_int"] = sparse_tensor.SparseTensor( array_ops.zeros([2, 2], dtypes.int64), - array_ops.zeros([2], dtypes.int64), - array_ops.zeros([2], dtypes.int64)) + array_ops.zeros([2], dtypes.int64), array_ops.zeros([2], + dtypes.int64)) (fc_names, dense_floats, sparse_float_indices, sparse_float_values, sparse_float_shapes, sparse_int_indices, sparse_int_values, sparse_int_shapes) = ( @@ -139,8 +135,8 @@ class GbdtTest(test_util.TensorFlowTestCase): array_ops.zeros([2], dtypes.int64)) features["sparse_categorical"] = sparse_tensor.SparseTensor( array_ops.zeros([2, 2], dtypes.int64), - array_ops.zeros( - [2], dtypes.string), array_ops.zeros([2], dtypes.int64)) + array_ops.zeros([2], dtypes.string), array_ops.zeros([2], + dtypes.int64)) feature_columns = set() feature_columns.add(layers.real_valued_column("dense_float")) feature_columns.add( @@ -235,7 +231,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -316,6 +313,113 @@ class GbdtTest(test_util.TensorFlowTestCase): }""" self.assertProtoEquals(expected_tree, output.trees[0]) + def testTrainFnChiefSparseAndDense(self): + """Tests the train function with sparse and dense features.""" + with self.test_session() as sess: + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config="", 
name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + features["dense_float"] = array_ops.ones([4, 1], dtypes.float32) + features["sparse_float"] = sparse_tensor.SparseTensor( + array_ops.zeros([2, 2], dtypes.int64), + array_ops.zeros([2], dtypes.float32), + array_ops.constant([4, 1], dtypes.int64)) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = variables.Variable( + initial_value=0, + name="ensemble_stamp", + trainable=False, + dtype=dtypes.int64) + + predictions_dict = { + "predictions": predictions, + "predictions_no_dropout": predictions, + "partition_ids": partition_ids, + "ensemble_stamp": ensemble_stamp, + "num_trees": 12, + } + + labels = array_ops.ones([4, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + # Create train op. + train_op = gbdt_model.train( + loss=math_ops.reduce_mean( + _squared_loss(labels, weights, predictions)), + predictions_dict=predictions_dict, + labels=labels) + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # On first run, expect no splits to be chosen because the quantile + # buckets will not be ready. 
+ train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 0) + self.assertEquals(len(output.tree_weights), 0) + self.assertEquals(stamp_token.eval(), 1) + + # Update the stamp to be able to run a second time. + sess.run([ensemble_stamp.assign_add(1)]) + + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 1) + self.assertAllClose(output.tree_weights, [0.1]) + self.assertEquals(stamp_token.eval(), 2) + expected_tree = """ + nodes { + sparse_float_binary_split_default_right { + split{ + left_id: 1 + right_id: 2 + } + } + node_metadata { + gain: 1.125 + } + } + nodes { + leaf { + vector { + value: 1.0 + } + } + } + nodes { + leaf { + vector { + value: -0.5 + } + } + }""" + self.assertProtoEquals(expected_tree, output.trees[0]) + def testTrainFnChiefScalingNumberOfExamples(self): """Tests the train function running on chief without bias centering.""" with self.test_session() as sess: @@ -339,7 +443,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=num_examples_fn, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -442,7 +547,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -513,7 +619,8 @@ class GbdtTest(test_util.TensorFlowTestCase): 
ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -576,7 +683,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -622,7 +730,8 @@ class GbdtTest(test_util.TensorFlowTestCase): with self.test_session() as sess: # Create ensemble with one bias node. ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() - text_format.Merge(""" + text_format.Merge( + """ trees { nodes { leaf { @@ -659,16 +768,129 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) # Create predict op. mode = model_fn.ModeKeys.EVAL predictions_dict = sess.run(gbdt_model.predict(mode)) self.assertEquals(predictions_dict["ensemble_stamp"], 3) - self.assertAllClose(predictions_dict["predictions"], [[0.25], [0.25], - [0.25], [0.25]]) + self.assertAllClose(predictions_dict["predictions"], + [[0.25], [0.25], [0.25], [0.25]]) self.assertAllClose(predictions_dict["partition_ids"], [0, 0, 0, 0]) + def testPredictFnWithLeafIndexAdvancedLeft(self): + """Tests the predict function with output leaf ids.""" + with self.test_session() as sess: + # Create ensemble with one bias node. 
+ ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + text_format.Merge( + """ + trees { + nodes { + dense_float_binary_split { + threshold: 1.0 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 0 + } + } + nodes { + leaf { + vector { + value: 0.25 + } + } + } + nodes { + leaf { + vector { + value: 0.15 + } + } + } + } + trees { + nodes { + dense_float_binary_split { + threshold: 0.99 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 00 + } + } + nodes { + leaf { + vector { + value: 0.25 + } + } + } + nodes { + leaf { + vector { + value: 0.23 + } + } + } + } + tree_weights: 1.0 + tree_weights: 1.0 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 1 + is_finalized: true + } + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 1 + is_finalized: true + }""", ensemble_config) + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=3, + tree_ensemble_config=ensemble_config.SerializeToString(), + name="tree_ensemble") + resources.initialize_resources(resources.shared_resources()).run() + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + features["dense_float"] = array_ops.constant( + [[0.0], [1.0], [1.1], [2.0]], dtype=dtypes.float32) + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=False, + num_ps_replicas=0, + center_bias=True, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features, + output_leaf_index=True) + + # Create predict op. 
+ mode = model_fn.ModeKeys.INFER + predictions_dict = sess.run(gbdt_model.predict(mode)) + self.assertEquals(predictions_dict["ensemble_stamp"], 3) + # Here is how the numbers in the expected results are calculated: + # 0.5 = 0.25 + 0.25 + # 0.48 = 0.25 + 0.23 + # 0.38 = 0.15 + 0.23 + # 0.38 = 0.15 + 0.23 + self.assertAllClose(predictions_dict["predictions"], + [[0.5], [0.48], [0.38], [0.38]]) + self.assertAllClose(predictions_dict["partition_ids"], [0, 0, 0, 0]) + self.assertAllClose(predictions_dict["leaf_index"], + [[1, 1], [1, 2], [2, 2], [2, 2]]) + def testTrainFnMulticlassFullHessian(self): """Tests the GBDT train for multiclass full hessian.""" with self.test_session() as sess: @@ -698,7 +920,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=5, features=features) + logits_dimension=5, + features=features) predictions = array_ops.constant( [[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0], @@ -801,7 +1024,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=5, features=features) + logits_dimension=5, + features=features) predictions = array_ops.constant( [[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0], @@ -893,8 +1117,8 @@ class GbdtTest(test_util.TensorFlowTestCase): learner_config.constraints.max_tree_depth = 1 learner_config.constraints.min_node_weight = 0 features = { - "dense_float": array_ops.constant( - [[1.0], [1.5], [2.0]], dtypes.float32), + "dense_float": + array_ops.constant([[1.0], [1.5], [2.0]], dtypes.float32), } gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( @@ -904,7 +1128,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=5, features=features) + logits_dimension=5, + features=features) batch_size = 3 predictions = 
array_ops.constant( @@ -986,7 +1211,8 @@ class GbdtTest(test_util.TensorFlowTestCase): self.assertAllClose( 0.893284678459, output.trees[0].nodes[2].leaf.sparse_vector.value[0], - atol=1e-4, rtol=1e-4) + atol=1e-4, + rtol=1e-4) def testTrainFnChiefFeatureSelectionReachedLimitNoGoodSplit(self): """Tests the train function running on chief with feature selection.""" @@ -1230,9 +1456,9 @@ class GbdtTest(test_util.TensorFlowTestCase): tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() tree = tree_ensemble_config.trees.add() - _set_float_split(tree.nodes.add() - .sparse_float_binary_split_default_right.split, 2, 4.0, - 1, 2) + _set_float_split( + tree.nodes.add().sparse_float_binary_split_default_right.split, 2, + 4.0, 1, 2) _append_to_leaf(tree.nodes.add().leaf, 0, 0.5) _append_to_leaf(tree.nodes.add().leaf, 1, 1.2) tree_ensemble_config.tree_weights.append(1.0) @@ -1241,7 +1467,8 @@ class GbdtTest(test_util.TensorFlowTestCase): metadata.num_layers_grown = 1 tree_ensemble_config = tree_ensemble_config.SerializeToString() ensemble_handle = model_ops.tree_ensemble_variable( - stamp_token=0, tree_ensemble_config=tree_ensemble_config, + stamp_token=0, + tree_ensemble_config=tree_ensemble_config, name="tree_ensemble") learner_config = learner_pb2.LearnerConfig() learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 @@ -1333,5 +1560,6 @@ class GbdtTest(test_util.TensorFlowTestCase): self.assertEquals(output.growing_metadata.num_layers_attempted, 2) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index e529b25b3caa1e9f7e08522de9e08401ef639eca..8ae493ba998bd882b5ef946f927ec1882d91f61d 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -14,25 +14,37 @@ # ============================================================================== """Tools for working with object-based checkpoints. 
- -For creating and managing dependencies: -@@CheckpointableObjectGraph +Visualization and inspection: @@dot_graph_from_checkpoint @@object_metadata + +Managing dependencies: +@@Checkpointable +@@CheckpointableObjectGraph @@NoDependency @@split_dependency + +Checkpointable data structures: +@@List +@@Mapping +@@UniqueNameTracker """ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph -from tensorflow.python.training.checkpointable import NoDependency -from tensorflow.python.training.checkpointable_utils import object_metadata +from tensorflow.python.training.checkpointable.base import Checkpointable +from tensorflow.python.training.checkpointable.base import NoDependency +from tensorflow.python.training.checkpointable.data_structures import List +from tensorflow.python.training.checkpointable.data_structures import Mapping +from tensorflow.python.training.checkpointable.util import object_metadata from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(module_name=__name__) + diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD index a5681ffa61d07ef29d0a0862db9736a210c8e26e..7b200a29bf60087d6da1010b0be05c04faec80cd 100644 --- a/tensorflow/contrib/checkpoint/python/BUILD +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -8,8 +8,35 @@ py_library( name = "checkpoint", srcs_version = "PY2AND3", deps = [ + ":containers", ":split_dependency", ":visualize", + "//tensorflow/python/training/checkpointable:data_structures", + ], +) + +py_library( + name = "containers", + srcs = ["containers.py"], + 
srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:data_structures", + ], +) + +py_test( + name = "containers_test", + srcs = ["containers_test.py"], + deps = [ + ":containers", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:util", + "@six_archive//:six", ], ) @@ -21,6 +48,7 @@ py_library( deps = [ "//tensorflow/python:control_flow_ops", "//tensorflow/python:training", + "//tensorflow/python/training/checkpointable:base", ], ) @@ -32,8 +60,9 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:training", "//tensorflow/python/eager:test", + "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:util", ], ) @@ -44,6 +73,8 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:util", ], ) @@ -52,10 +83,13 @@ py_test( srcs = ["visualize_test.py"], deps = [ ":visualize", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_test_lib", + "//tensorflow/python:constant_op", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", + "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", + "//tensorflow/python/keras:engine", + "//tensorflow/python/keras:layers", + "//tensorflow/python/training/checkpointable:util", ], ) diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py new file mode 100644 index 
0000000000000000000000000000000000000000..4d3d5312993740636709cb732c0b8e3e2626262d --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -0,0 +1,80 @@ +"""Checkpointable data structures.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.training.checkpointable import base as checkpointable_lib +from tensorflow.python.training.checkpointable import data_structures + + +class UniqueNameTracker(data_structures.CheckpointableDataStructure): + """Adds dependencies on checkpointable objects with name hints. + + Useful for creating dependencies with locally unique names. + + Example usage: + ```python + class SlotManager(tf.contrib.checkpoint.Checkpointable): + + def __init__(self): + # Create a dependency named "slotdeps" on the container. 
+ self.slotdeps = tf.contrib.checkpoint.UniqueNameTracker() + slotdeps = self.slotdeps + slots = [] + slots.append(slotdeps.track(tfe.Variable(3.), "x")) # Named "x" + slots.append(slotdeps.track(tfe.Variable(4.), "y")) + slots.append(slotdeps.track(tfe.Variable(5.), "x")) # Named "x_1" + ``` + """ + + def __init__(self): + super(UniqueNameTracker, self).__init__() + self._maybe_initialize_checkpointable() + self._name_counts = {} + + def track(self, checkpointable, base_name): + """Add a dependency on `checkpointable`. + + Args: + checkpointable: An object to add a checkpoint dependency on. + base_name: A name hint, which is uniquified to determine the dependency + name. + Returns: + `checkpointable`, for chaining. + Raises: + ValueError: If `checkpointable` is not a checkpointable object. + """ + + if not isinstance(checkpointable, checkpointable_lib.CheckpointableBase): + raise ValueError( + ("Expected a checkpointable value, got %s which does not inherit " + "from CheckpointableBase.") % (checkpointable,)) + + def _format_name(prefix, number): + if number > 0: + return "%s_%d" % (prefix, number) + else: + return prefix + + count = self._name_counts.get(base_name, 0) + candidate = _format_name(base_name, count) + while self._lookup_dependency(candidate) is not None: + count += 1 + candidate = _format_name(base_name, count) + self._name_counts[base_name] = count + 1 + self._track_value(checkpointable, name=candidate) + return checkpointable diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3717d7f583ffdc205a279d45df60cddbc5cbf08e --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/containers_test.py @@ -0,0 +1,108 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import six + +from tensorflow.contrib.checkpoint.python import containers +from tensorflow.python.framework import test_util +from tensorflow.python.keras import layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import util as checkpointable_utils + + +class UniqueNameTrackerTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testNames(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + + x1 = resource_variable_ops.ResourceVariable(2.) + x2 = resource_variable_ops.ResourceVariable(3.) + x3 = resource_variable_ops.ResourceVariable(4.) + y = resource_variable_ops.ResourceVariable(5.) 
+ slots = containers.UniqueNameTracker() + slots.track(x1, "x") + slots.track(x2, "x") + slots.track(x3, "x_1") + slots.track(y, "y") + self.evaluate((x1.initializer, x2.initializer, x3.initializer, + y.initializer)) + save_root = checkpointable_utils.Checkpoint(slots=slots) + save_path = save_root.save(checkpoint_prefix) + + restore_slots = checkpointable.Checkpointable() + restore_root = checkpointable_utils.Checkpoint( + slots=restore_slots) + status = restore_root.restore(save_path) + restore_slots.x = resource_variable_ops.ResourceVariable(0.) + restore_slots.x_1 = resource_variable_ops.ResourceVariable(0.) + restore_slots.x_1_1 = resource_variable_ops.ResourceVariable(0.) + restore_slots.y = resource_variable_ops.ResourceVariable(0.) + status.assert_consumed().run_restore_ops() + self.assertEqual(2., self.evaluate(restore_slots.x)) + self.assertEqual(3., self.evaluate(restore_slots.x_1)) + self.assertEqual(4., self.evaluate(restore_slots.x_1_1)) + self.assertEqual(5., self.evaluate(restore_slots.y)) + + @test_util.run_in_graph_and_eager_modes() + def testExample(self): + class SlotManager(checkpointable.Checkpointable): + + def __init__(self): + self.slotdeps = containers.UniqueNameTracker() + slotdeps = self.slotdeps + slots = [] + slots.append(slotdeps.track( + resource_variable_ops.ResourceVariable(3.), "x")) + slots.append(slotdeps.track( + resource_variable_ops.ResourceVariable(4.), "y")) + slots.append(slotdeps.track( + resource_variable_ops.ResourceVariable(5.), "x")) + self.slots = slots + + manager = SlotManager() + self.evaluate([v.initializer for v in manager.slots]) + checkpoint = checkpointable_utils.Checkpoint(slot_manager=manager) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = checkpoint.save(checkpoint_prefix) + metadata = checkpointable_utils.object_metadata(save_path) + dependency_names = [] + for node in metadata.nodes: + for child in node.children: + 
dependency_names.append(child.local_name) + six.assertCountEqual( + self, + dependency_names, + ["x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"]) + + @test_util.run_in_graph_and_eager_modes() + def testLayers(self): + tracker = containers.UniqueNameTracker() + tracker.track(layers.Dense(3), "dense") + tracker.layers[0](array_ops.zeros([1, 1])) + self.assertEqual(2, len(tracker.trainable_weights)) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/checkpoint/python/split_dependency.py b/tensorflow/contrib/checkpoint/python/split_dependency.py index 3aec8c96e90440d6da00d95cffc34bd53ec7164f..7e77453f3d848c2e321ed2ba66917a742d95459a 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency.py @@ -20,8 +20,8 @@ from __future__ import print_function import functools from tensorflow.python.ops import control_flow_ops -from tensorflow.python.training import checkpointable as checkpointable from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training.checkpointable import base as checkpointable class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): diff --git a/tensorflow/contrib/checkpoint/python/split_dependency_test.py b/tensorflow/contrib/checkpoint/python/split_dependency_test.py index f1d9d19b047ee69281cf8bdba38a28dc87947e38..69dc0b9be2d5548852c37552a64a0d31c9557b43 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency_test.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency_test.py @@ -23,8 +23,8 @@ from tensorflow.python.eager import test from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.training import checkpointable -from tensorflow.python.training import checkpointable_utils +from tensorflow.python.training.checkpointable import base as checkpointable +from 
tensorflow.python.training.checkpointable import util as checkpointable_utils def _split_variable_closure(variable): diff --git a/tensorflow/contrib/checkpoint/python/visualize.py b/tensorflow/contrib/checkpoint/python/visualize.py index 9a3b23bb2c30ee601f5f94da31ad182399a04e4f..bac071c4cff383f60b707b6e42c13faf5e0ac948 100644 --- a/tensorflow/contrib/checkpoint/python/visualize.py +++ b/tensorflow/contrib/checkpoint/python/visualize.py @@ -18,8 +18,8 @@ from __future__ import division from __future__ import print_function from tensorflow.python import pywrap_tensorflow -from tensorflow.python.training import checkpointable -from tensorflow.python.training import checkpointable_utils +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import util as checkpointable_utils def dot_graph_from_checkpoint(save_path): diff --git a/tensorflow/contrib/checkpoint/python/visualize_test.py b/tensorflow/contrib/checkpoint/python/visualize_test.py index 1d9ab789235cb964521315b4864563f89745ae75..583e3bc442893d825c337d73fb999d1e586738a1 100644 --- a/tensorflow/contrib/checkpoint/python/visualize_test.py +++ b/tensorflow/contrib/checkpoint/python/visualize_test.py @@ -24,11 +24,11 @@ from tensorflow.contrib.checkpoint.python import visualize from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op -from tensorflow.python.keras._impl.keras.engine import training -from tensorflow.python.keras._impl.keras.layers import core +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import core from tensorflow.python.ops import resource_variable_ops from tensorflow.python.training import adam -from tensorflow.python.training import checkpointable_utils +from tensorflow.python.training.checkpointable import util as checkpointable_utils try: import pydot # pylint: disable=g-import-not-at-top diff --git 
a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index f3a75e8688ece19a6e6fd53ee9faf7f4144d76cf..42ba368531468b789a87429f88ca84937f9b909d 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -15,7 +15,10 @@ load( ) tf_gen_op_libs( - op_lib_names = ["bigquery_reader_ops"], + op_lib_names = [ + "bigquery_reader_ops", + "gcs_config_ops", + ], deps = [ "//tensorflow/core:lib", ], @@ -28,15 +31,25 @@ tf_gen_op_wrapper_py( deps = [":bigquery_reader_ops_op_lib"], ) +tf_gen_op_wrapper_py( + name = "gen_gcs_config_ops", + out = "python/ops/gen_gcs_config_ops.py", + require_shape_functions = True, + visibility = ["//tensorflow:internal"], + deps = [":gcs_config_ops_op_lib"], +) + py_library( name = "cloud_py", srcs = [ "__init__.py", "python/ops/bigquery_reader_ops.py", + "python/ops/gcs_config_ops.py", ], srcs_version = "PY2AND3", deps = [ ":gen_bigquery_reader_ops", + ":gen_gcs_config_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:io_ops", "//tensorflow/python:util", diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py index 8870264b95dfd9f8c4b1655c475fe23e0639924f..a6e13ea3ae938444b9ead0772e52fb8797a847da 100644 --- a/tensorflow/contrib/cloud/__init__.py +++ b/tensorflow/contrib/cloud/__init__.py @@ -20,9 +20,15 @@ from __future__ import print_function # pylint: disable=line-too-long,wildcard-import from tensorflow.contrib.cloud.python.ops.bigquery_reader_ops import * +from tensorflow.contrib.cloud.python.ops.gcs_config_ops import * # pylint: enable=line-too-long,wildcard-import from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['BigQueryReader'] +_allowed_symbols = [ + 'BigQueryReader', + 'ConfigureColabSession', + 'ConfigureGcs', + 'ConfigureGcsHook', +] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index 
ff46f0daa80a70badedf73e15bfaf4dca85fdd89..40160706f70e8fa8323005dd183770ed51c8c415 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -73,3 +73,17 @@ tf_proto_library( srcs = ["bigquery_table_partition.proto"], cc_api_version = 2, ) + +tf_kernel_library( + name = "gcs_config_ops", + srcs = ["gcs_config_ops.cc"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/platform/cloud:curl_http_request", + "//tensorflow/core/platform/cloud:gcs_file_system", + "//tensorflow/core/platform/cloud:oauth_client", + "@jsoncpp_git//:jsoncpp", + ], +) diff --git a/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..648a219fb87a6ebc64767a7da780013ef6b95443 --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc @@ -0,0 +1,205 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <sstream> + +#include "include/json/json.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/cloud/curl_http_request.h" +#include "tensorflow/core/platform/cloud/gcs_file_system.h" +#include "tensorflow/core/platform/cloud/oauth_client.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { +namespace { + +// The default initial delay between retries with exponential backoff. +constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec + +// The minimum time delta between now and the token expiration time +// for the token to be re-used. +constexpr int kExpirationTimeMarginSec = 60; + +// The URL to retrieve the auth bearer token via OAuth with a refresh token. +constexpr char kOAuthV3Url[] = "https://www.googleapis.com/oauth2/v3/token"; + +// The URL to retrieve the auth bearer token via OAuth with a private key. +constexpr char kOAuthV4Url[] = "https://www.googleapis.com/oauth2/v4/token"; + +// The authentication token scope to request.
+constexpr char kOAuthScope[] = "https://www.googleapis.com/auth/cloud-platform"; + +Status RetrieveGcsFs(OpKernelContext* ctx, RetryingGcsFileSystem** fs) { + DCHECK(fs != nullptr); + *fs = nullptr; + + FileSystem* filesystem = nullptr; + TF_RETURN_IF_ERROR( + ctx->env()->GetFileSystemForFile("gs://fake/file.text", &filesystem)); + if (filesystem == nullptr) { + return errors::FailedPrecondition("The GCS file system is not registered."); + } + + *fs = dynamic_cast<RetryingGcsFileSystem*>(filesystem); + if (*fs == nullptr) { + return errors::Internal( + "The filesystem registered under the 'gs://' scheme was not a " + "tensorflow::RetryingGcsFileSystem*."); + } + return Status::OK(); +} + +template <typename T> +Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name, + T* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsScalar(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a scalar"); + } + *output = argument_t->scalar<T>()(); + return Status::OK(); +} + +// GcsCredentialsOpKernel overrides the credentials used by the gcs_filesystem. +class GcsCredentialsOpKernel : public OpKernel { + public: + explicit GcsCredentialsOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + // Get a handle to the GCS file system.
+ RetryingGcsFileSystem* gcs = nullptr; + OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs)); + + string json_string; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "json", &json_string)); + + Json::Value json; + Json::Reader reader; + std::stringstream json_stream(json_string); + OP_REQUIRES(ctx, reader.parse(json_stream, json), + errors::InvalidArgument("Could not parse json: ", json_string)); + + OP_REQUIRES( + ctx, json.isMember("refresh_token") || json.isMember("private_key"), + errors::InvalidArgument("JSON format incompatible; did not find fields " + "`refresh_token` or `private_key`.")); + + auto provider = + tensorflow::MakeUnique<ConstantAuthProvider>(json, ctx->env()); + + // Test getting a token + string dummy_token; + OP_REQUIRES_OK(ctx, provider->GetToken(&dummy_token)); + OP_REQUIRES(ctx, !dummy_token.empty(), + errors::InvalidArgument( + "Could not retrieve a token with the given credentials.")); + + // Set the provider. + gcs->underlying()->SetAuthProvider(std::move(provider)); + } + + private: + class ConstantAuthProvider : public AuthProvider { + public: + ConstantAuthProvider(const Json::Value& json, + std::unique_ptr<OAuthClient> oauth_client, Env* env, + int64 initial_retry_delay_usec) + : json_(json), + oauth_client_(std::move(oauth_client)), + env_(env), + initial_retry_delay_usec_(initial_retry_delay_usec) {} + + ConstantAuthProvider(const Json::Value& json, Env* env) + : ConstantAuthProvider(json, tensorflow::MakeUnique<OAuthClient>(), env, + kInitialRetryDelayUsec) {} + + ~ConstantAuthProvider() override {} + + Status GetToken(string* token) override { + mutex_lock l(mu_); + const uint64 now_sec = env_->NowSeconds(); + + if (!current_token_.empty() && + now_sec + kExpirationTimeMarginSec < expiration_timestamp_sec_) { + *token = current_token_; + return Status::OK(); + } + if (json_.isMember("refresh_token")) { + TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromRefreshTokenJson( + json_, kOAuthV3Url, &current_token_, &expiration_timestamp_sec_)); + } 
else if (json_.isMember("private_key")) { +
TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromServiceAccountJson( + json_, kOAuthV4Url, kOAuthScope, &current_token_, + &expiration_timestamp_sec_)); + } else { + return errors::FailedPrecondition( + "Unexpected content of the JSON credentials file."); + } + + *token = current_token_; + return Status::OK(); + } + + private: + Json::Value json_; + std::unique_ptr<OAuthClient> oauth_client_; + Env* env_; + + mutex mu_; + string current_token_ GUARDED_BY(mu_); + uint64 expiration_timestamp_sec_ GUARDED_BY(mu_) = 0; + + // The initial delay for exponential backoffs when retrying failed calls. + const int64 initial_retry_delay_usec_; + TF_DISALLOW_COPY_AND_ASSIGN(ConstantAuthProvider); + }; +}; + +REGISTER_KERNEL_BUILDER(Name("GcsConfigureCredentials").Device(DEVICE_CPU), + GcsCredentialsOpKernel); + +class GcsBlockCacheOpKernel : public OpKernel { + public: + explicit GcsBlockCacheOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + // Get a handle to the GCS file system.
+ RetryingGcsFileSystem* gcs = nullptr; + OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs)); + + size_t max_cache_size, block_size, max_staleness; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "max_cache_size", + &max_cache_size)); + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "block_size", &block_size)); + OP_REQUIRES_OK( + ctx, ParseScalarArgument(ctx, "max_staleness", &max_staleness)); + + if (gcs->underlying()->block_size() == block_size && + gcs->underlying()->max_bytes() == max_cache_size && + gcs->underlying()->max_staleness() == max_staleness) { + LOG(INFO) << "Skipping resetting the GCS block cache."; + return; + } + gcs->underlying()->ResetFileBlockCache(block_size, max_cache_size, + max_staleness); + } +}; + +REGISTER_KERNEL_BUILDER(Name("GcsConfigureBlockCache").Device(DEVICE_CPU), + GcsBlockCacheOpKernel); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/ops/gcs_config_ops.cc b/tensorflow/contrib/cloud/ops/gcs_config_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..9cf85f5f1811d873075b6d2e1931d8badfd6e32c --- /dev/null +++ b/tensorflow/contrib/cloud/ops/gcs_config_ops.cc @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("GcsConfigureCredentials") + .Input("json: string") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Configures the credentials used by the GCS client of the local TF runtime. + +The json input can be of the format: + +1. Refresh Token: +{ + "client_id": "", + "client_secret": "", + "refresh_token: "", + "type": "authorized_user", +} + +2. Service Account: +{ + "type": "service_account", + "project_id": "", + "private_key_id": "", + "private_key": "------BEGIN PRIVATE KEY-----\n\n-----END PRIVATE KEY------\n", + "client_email": "@.iam.gserviceaccount.com", + "client_id": "", + # Some additional fields elided +} + +Note the credentials established through this method are shared across all +sessions run on this runtime. + +Note be sure to feed the inputs to this op to ensure the credentials are not +stored in a constant op within the graph that might accidentally be checkpointed +or in other ways be persisted or exfiltrated. +)doc"); + +REGISTER_OP("GcsConfigureBlockCache") + .Input("max_cache_size: uint64") + .Input("block_size: uint64") + .Input("max_staleness: uint64") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Re-configures the GCS block cache with the new configuration values. + +If the values are the same as already configured values, this op is a no-op. If +they are different, the current contents of the block cache is dropped, and a +new block cache is created fresh. 
+)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..8c8c5acb31af69b4f738a13c6548cdd31947d71a --- /dev/null +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py @@ -0,0 +1,188 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""GCS file system configuration for TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json + +from tensorflow.contrib.cloud.python.ops import gen_gcs_config_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.training import training + + +# @tf_export('contrib.cloud.BlockCacheParams') +class BlockCacheParams(object): + """BlockCacheParams is a struct used for configuring the GCS Block Cache.""" + + def __init__(self, block_size=None, max_bytes=None, max_staleness=None): + self._block_size = block_size or 128 * 1024 * 1024 + self._max_bytes = max_bytes or 2 * self._block_size + self._max_staleness = max_staleness or 0 + + @property + def block_size(self): + return self._block_size + + @property + def max_bytes(self): + 
return self._max_bytes + + @property + def max_staleness(self): + return self._max_staleness + + +# @tf_export('contrib.cloud.ConfigureGcsHook') +class ConfigureGcsHook(training.SessionRunHook): + """ConfigureGcsHook configures GCS when used with Estimator/TPUEstimator. + + Warning: GCS `credentials` may be transmitted over the network unencrypted. + Please ensure that the network is trusted before using this function. For + users running code entirely within Google Cloud, your data is protected by + encryption in between data centers. For more information, please take a look + at https://cloud.google.com/security/encryption-in-transit/. + + Example: + + ``` + sess = tf.Session() + refresh_token = raw_input("Refresh token: ") + client_secret = raw_input("Client secret: ") + client_id = "" + creds = { + "client_id": client_id, + "refresh_token": refresh_token, + "client_secret": client_secret, + "type": "authorized_user", + } + tf.contrib.cloud.configure_gcs(sess, credentials=creds) + ``` + + """ + + def _verify_dictionary(self, creds_dict): + if 'refresh_token' in creds_dict or 'private_key' in creds_dict: + return True + return False + + def __init__(self, credentials=None, block_cache=None): + """Constructs a ConfigureGcsHook. + + Args: + credentials: A json-formatted string. + block_cache: A `BlockCacheParams` + + Raises: + ValueError: If credentials is improperly formatted or block_cache is not a + BlockCacheParams. 
+ """ + if credentials is not None: + if isinstance(credentials, str): + try: + data = json.loads(credentials) + except ValueError as e: + raise ValueError('credentials was not a well formed JSON string.', e) + if not self._verify_dictionary(data): + raise ValueError( + 'credentials has neither a "refresh_token" nor a "private_key" ' + 'field.') + elif isinstance(credentials, dict): + if not self._verify_dictionary(credentials): + raise ValueError('credentials has neither a "refresh_token" nor a ' + '"private_key" field.') + credentials = json.dumps(credentials) + else: + raise ValueError('credentials is of an unknown type') + + self._credentials = credentials + + if block_cache and not isinstance(block_cache, BlockCacheParams): + raise ValueError('block_cache must be an instance of BlockCacheParams.') + self._block_cache = block_cache + + def begin(self): + if self._credentials: + self._credentials_placeholder = array_ops.placeholder(dtypes.string) + self._credentials_op = gen_gcs_config_ops.gcs_configure_credentials( + self._credentials_placeholder) + if self._block_cache: + self._block_cache_op = gen_gcs_config_ops.gcs_configure_block_cache( + max_cache_size=self._block_cache.max_bytes, + block_size=self._block_cache.block_size, + max_staleness=self._block_cache.max_staleness) + + def after_create_session(self, session, coord): + del coord + if self._credentials_op: + session.run( + self._credentials_op, + feed_dict={self._credentials_placeholder: self._credentials}) + if self._block_cache_op: + session.run(self._block_cache_op) + + +def configure_gcs(session, credentials=None, block_cache=None, device=None): + """Configures the GCS file system for a given session. + + Warning: GCS `credentials` may be transmitted over the network unencrypted. + Please ensure that the network is trusted before using this function. For + users running code entirely within Google Cloud, your data is protected by + encryption in between data centers.
For more information, please take a look + at https://cloud.google.com/security/encryption-in-transit/. + + Args: + session: A `tf.Session` session that should be used to configure the GCS + file system. + credentials: [Optional.] A JSON string + block_cache: [Optional.] A BlockCacheParams to configure the block cache . + device: [Optional.] The device to place the configure ops. + """ + + def configure(credentials, block_cache): + """Helper function to actually configure GCS.""" + if credentials: + if isinstance(credentials, dict): + credentials = json.dumps(credentials) + placeholder = array_ops.placeholder(dtypes.string) + op = gen_gcs_config_ops.gcs_configure_credentials(placeholder) + session.run(op, feed_dict={placeholder: credentials}) + if block_cache: + op = gen_gcs_config_ops.gcs_configure_block_cache( + max_cache_size=block_cache.max_bytes, + block_size=block_cache.block_size, + max_staleness=block_cache.max_staleness) + session.run(op) + + if device: + with ops.device(device): + return configure(credentials, block_cache) + return configure(credentials, block_cache) + + +def configure_colab_session(session): + """ConfigureColabSession configures the GCS file system in Colab. + + Args: + session: A `tf.Session` session. + """ + # Read from the application default credentials (adc). 
+ with open('/content/datalab/adc.json') as f: + data = json.load(f) + configure_gcs(session, credentials=data) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 1403483d287041b02dfbf538f7e7ddee11662f47..8f521ffee4d31e090c13bac98290656d6e1d330e 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -36,6 +36,9 @@ except ImportError: _GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' +_ENDPOINTS_SEPARATOR = ',' +_DEFAULT_ENV_VARIABLE = 'TPU_NAME' +_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL' class TPUClusterResolver(ClusterResolver): @@ -67,8 +70,18 @@ class TPUClusterResolver(ClusterResolver): return _GKE_ENV_VARIABLE in os.environ @staticmethod - def _gkeMaster(): - return os.environ[_GKE_ENV_VARIABLE].split(',')[0] + def _gkeEndpoints(): + return os.environ[_GKE_ENV_VARIABLE] + + @staticmethod + def _envVarFallback(): + if _DEFAULT_ENV_VARIABLE in os.environ: + return os.environ[_DEFAULT_ENV_VARIABLE] + return None + + @staticmethod + def _discoveryUrl(): + return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE) def __init__(self, tpu=None, @@ -78,7 +91,8 @@ class TPUClusterResolver(ClusterResolver): coordinator_name=None, coordinator_address=None, credentials='default', - service=None): + service=None, + discovery_url=None): """Creates a new TPUClusterResolver object. The ClusterResolver will then use the parameters to query the Cloud TPU APIs @@ -108,6 +122,11 @@ class TPUClusterResolver(ClusterResolver): service: The GCE API object returned by the googleapiclient.discovery function. If you specify a custom service object, then the credentials parameter will be ignored. + discovery_url: A URL template that points to the location of + the discovery service. 
It should have two parameters {api} and + {apiVersion} that when filled in produce an absolute URL to the + discovery document for that service. The environment variable + 'TPU_API_DISCOVERY_URL' will override this. Raises: ImportError: If the googleapiclient is not installed. @@ -123,8 +142,11 @@ class TPUClusterResolver(ClusterResolver): in_gke = self._inGke() # When using GKE with Cloud TPUs, the env variable will be set. - if tpu is None and in_gke: - tpu = self._gkeMaster() + if tpu is None: + if in_gke: + tpu = self._gkeEndpoints() + else: + tpu = self._envVarFallback() self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes self._job_name = job_name @@ -149,14 +171,22 @@ class TPUClusterResolver(ClusterResolver): if service is None and should_resolve: if not _GOOGLE_API_CLIENT_INSTALLED: - raise ImportError('googleapiclient must be installed before using the ' - 'TPU cluster resolver. Execute: `pip install ' - '--upgrade google-api-python-client` to install with ' - 'pip.') - - self._service = discovery.build( - 'tpu', 'v1alpha1', - credentials=self._credentials) + raise ImportError('googleapiclient and oauth2client must be installed ' + 'before using the TPU cluster resolver. Execute: ' + '`pip install --upgrade google-api-python-client` ' + 'and `pip install --upgrade oauth2client` to ' + 'install with pip.') + + final_discovery_url = self._discoveryUrl() or discovery_url + if final_discovery_url: + self._service = discovery.build( + 'tpu', 'v1alpha1', + credentials=self._credentials, + discoveryServiceUrl=final_discovery_url) + else: + self._service = discovery.build( + 'tpu', 'v1alpha1', + credentials=self._credentials) else: self._service = service @@ -185,7 +215,7 @@ class TPUClusterResolver(ClusterResolver): ValueError: If none of the TPUs specified exists. 
""" if not self._shouldResolve(): - return self._tpu + return self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0] job_tasks = self.cluster_spec().job_tasks(self._job_name) if not job_tasks: @@ -227,6 +257,10 @@ class TPUClusterResolver(ClusterResolver): request = self._service.projects().locations().nodes().get(name=full_name) response = request.execute() + if 'state' in response and response['state'] != 'READY': + raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' % + (self._tpu, response['state'])) + if 'health' in response and response['health'] != 'HEALTHY': raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu, response['health'])) @@ -247,8 +281,12 @@ class TPUClusterResolver(ClusterResolver): # Case 3. return None # Case 2. - cluster_spec = {self._job_name: [self._tpu[len( - compat.as_bytes('grpc://')):]]} + cluster_spec = { + self._job_name: [ + x[len(compat.as_bytes('grpc://')):] + for x in self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR)) + ] + } if self._coordinator_address: # {1, 2}.a diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index 5b3f9be5a11237f9dceebefa1db294efaf7e482d..ad4f6432630be44a7de6e778f55f1fb7fd66f307 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -158,6 +158,50 @@ class TPUClusterResolverTest(test.TestCase): """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testUnhealthyCloudTpu(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + 'health': 'UNHEALTHY' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project=None, + zone=None, 
+ tpu='test-tpu-1', + coordinator_name=None, + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + with self.assertRaises(RuntimeError): + tpu_cluster_resolver.cluster_spec() + + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testNotReadyCloudTpu(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + 'state': 'CREATING' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project=None, + zone=None, + tpu='test-tpu-1', + coordinator_name=None, + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + with self.assertRaises(RuntimeError): + tpu_cluster_resolver.cluster_spec() + def testSimpleSuccessfulRetrieval(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { @@ -358,15 +402,67 @@ class TPUClusterResolverTest(test.TestCase): compat.as_bytes('/bns/foo/bar'), tpu_cluster_resolver.master()) self.assertEqual(None, tpu_cluster_resolver.cluster_spec()) - def testGkeEnvironment(self): + def testGkeEnvironmentForDonut(self): os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470' - self.assertTrue('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' in os.environ) + + self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ) + self.assertTrue(TPUClusterResolver._inGke()) + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470'), + compat.as_bytes(TPUClusterResolver._gkeEndpoints())) + + tpu_cluster_resolver = TPUClusterResolver() + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470'), + compat.as_bytes(tpu_cluster_resolver.master())) + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'worker' + tasks { key: 0 value: '10.120.27.5:8470' } + } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + + del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] + + def 
testGkeEnvironmentForPod(self): + os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = ('grpc://10.120.27.5:8470,' + 'grpc://10.120.27.6:8470,' + 'grpc://10.120.27.7:8470,' + 'grpc://10.120.27.8:8470') + + self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ) self.assertTrue(TPUClusterResolver._inGke()) + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470,' + 'grpc://10.120.27.6:8470,' + 'grpc://10.120.27.7:8470,' + 'grpc://10.120.27.8:8470'), + compat.as_bytes(TPUClusterResolver._gkeEndpoints())) + + tpu_cluster_resolver = TPUClusterResolver() self.assertEqual( compat.as_bytes('grpc://10.120.27.5:8470'), - compat.as_bytes(TPUClusterResolver._gkeMaster())) + compat.as_bytes(tpu_cluster_resolver.master())) + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'worker' + tasks { key: 0 value: '10.120.27.5:8470' } + tasks { key: 1 value: '10.120.27.6:8470' } + tasks { key: 2 value: '10.120.27.7:8470' } + tasks { key: 3 value: '10.120.27.8:8470' } + } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] + def testDiscoveryUrl(self): + os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}' + self.assertEqual('https://{api}.internal/{apiVersion}', + TPUClusterResolver._discoveryUrl()) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index 0708d6b7b9f0ba549aea091a265f42890e50d223..e524e9e7437b19e0d117fe7b85042e8154773a02 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -18,7 +18,16 @@ cmake_policy(SET CMP0022 NEW) # Options option(tensorflow_VERBOSE "Enable for verbose output" OFF) + +if(WIN32) +# BoringSSL is disabled for windows as it currently doesn't build with +# MSBuild. (Ninja is required.) 
option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" OFF) +else() +# BoringSSL is enabled for gRPC. +option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" ON) +endif() + option(tensorflow_ENABLE_GRPC_SUPPORT "Enable gRPC support" ON) option(tensorflow_ENABLE_HDFS_SUPPORT "Enable HDFS support" OFF) option(tensorflow_ENABLE_JEMALLOC_SUPPORT "Enable jemalloc support" OFF) diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index 693dc7cd673233b889b35a3f3170b57581da9a9f..b1e64aa55c80ad59cfdc0f4767c0282b4f73367f 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -20,6 +20,10 @@ set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) set(GRPC_TAG d184fa229d75d336aedea0041bd59cb93e7e267f) if(WIN32) + # We use unsecure gRPC because boringssl does not build on windows + set(grpc_TARGET grpc++_unsecure) + set(grpc_DEPENDS protobuf zlib) + set(grpc_SSL_PROVIDER NONE) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") set(grpc_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc++_unsecure.lib @@ -32,9 +36,12 @@ if(WIN32) ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/gpr.lib) endif() else() + set(grpc_TARGET grpc++) + set(grpc_DEPENDS boringssl protobuf zlib) + set(grpc_SSL_PROVIDER module) set(grpc_STATIC_LIBRARIES - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc_unsecure.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libaddress_sorting.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/cares/lib/libcares.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a) @@ -44,13 +51,13 @@ add_definitions(-DGRPC_ARES=0) ExternalProject_Add(grpc PREFIX grpc - DEPENDS protobuf zlib + DEPENDS ${grpc_DEPENDS} GIT_REPOSITORY ${GRPC_URL} GIT_TAG ${GRPC_TAG} 
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${grpc_STATIC_LIBRARIES} - BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc++_unsecure + BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target ${grpc_TARGET} COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc_cpp_plugin INSTALL_COMMAND "" CMAKE_CACHE_ARGS @@ -59,7 +66,7 @@ ExternalProject_Add(grpc -DPROTOBUF_INCLUDE_DIRS:STRING=${PROTOBUF_INCLUDE_DIRS} -DPROTOBUF_LIBRARIES:STRING=${protobuf_STATIC_LIBRARIES} -DZLIB_ROOT:STRING=${ZLIB_INSTALL} - -DgRPC_SSL_PROVIDER:STRING=NONE + -DgRPC_SSL_PROVIDER:STRING=${grpc_SSL_PROVIDER} ) # grpc/src/core/ext/census/tracing.c depends on the existence of openssl/rand.h. diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 6468bed4979253be5c20666d26bf24fa479d64a0..015cb73bbd93bb77f6748a364b263d99eb305c27 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -32,52 +32,13 @@ tensorflow/python/feature_column tensorflow/python/framework tensorflow/python/grappler tensorflow/python/keras -tensorflow/python/keras/activations tensorflow/python/keras/applications -tensorflow/python/keras/applications/densenet -tensorflow/python/keras/applications/inception_resnet_v2 -tensorflow/python/keras/applications/inception_v3 -tensorflow/python/keras/applications/mobilenet -tensorflow/python/keras/applications/nasnet -tensorflow/python/keras/applications/resnet50 -tensorflow/python/keras/applications/vgg16 -tensorflow/python/keras/applications/vgg19 -tensorflow/python/keras/applications/xception -tensorflow/python/keras/backend -tensorflow/python/keras/callbacks -tensorflow/python/keras/constraints tensorflow/python/keras/datasets -tensorflow/python/keras/datasets/boston_housing -tensorflow/python/keras/datasets/cifar10 -tensorflow/python/keras/datasets/cifar100 -tensorflow/python/keras/datasets/fashion_mnist 
-tensorflow/python/keras/datasets/imdb -tensorflow/python/keras/datasets/mnist -tensorflow/python/keras/datasets/reuters -tensorflow/python/keras/estimator -tensorflow/python/keras/initializers +tensorflow/python/keras/engine tensorflow/python/keras/layers -tensorflow/python/keras/losses -tensorflow/python/keras/metrics -tensorflow/python/keras/models -tensorflow/python/keras/optimizers tensorflow/python/keras/preprocessing -tensorflow/python/keras/preprocessing/image -tensorflow/python/keras/preprocessing/sequence -tensorflow/python/keras/preprocessing/text -tensorflow/python/keras/regularizers tensorflow/python/keras/utils tensorflow/python/keras/wrappers -tensorflow/python/keras/wrappers/scikit_learn -tensorflow/python/keras/_impl -tensorflow/python/keras/_impl/keras -tensorflow/python/keras/_impl/keras/applications -tensorflow/python/keras/_impl/keras/datasets -tensorflow/python/keras/_impl/keras/engine -tensorflow/python/keras/_impl/keras/layers -tensorflow/python/keras/_impl/keras/preprocessing -tensorflow/python/keras/_impl/keras/utils -tensorflow/python/keras/_impl/keras/wrappers tensorflow/python/kernel_tests tensorflow/python/kernel_tests/boosted_trees tensorflow/python/kernel_tests/distributions @@ -100,6 +61,7 @@ tensorflow/python/summary tensorflow/python/summary/writer tensorflow/python/tools tensorflow/python/training +tensorflow/python/training/checkpointable tensorflow/python/user_ops tensorflow/python/util tensorflow/python/util/protobuf @@ -153,6 +115,8 @@ tensorflow/contrib/coder/python/ops tensorflow/contrib/compiler tensorflow/contrib/constrained_optimization tensorflow/contrib/constrained_optimization/python +tensorflow/contrib/control_flow +tensorflow/contrib/control_flow/python tensorflow/contrib/copy_graph tensorflow/contrib/copy_graph/python tensorflow/contrib/copy_graph/python/util @@ -333,6 +297,8 @@ tensorflow/contrib/metrics tensorflow/contrib/metrics/python tensorflow/contrib/metrics/python/metrics 
tensorflow/contrib/metrics/python/ops +tensorflow/contrib/mixed_precision +tensorflow/contrib/mixed_precision/python tensorflow/contrib/mpi_collectives/python tensorflow/contrib/mpi_collectives/python/ops tensorflow/contrib/model_pruning diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt index d63c41db844af243f0c6600b1565635ac9b91cac..cf1ee2ad76f2cc9f58dbe90182a3e17f1edc7ed3 100644 --- a/tensorflow/contrib/cmake/python_protos.txt +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -11,7 +11,6 @@ tensorflow/contrib/mpi tensorflow/contrib/mpi_collectives tensorflow/contrib/session_bundle tensorflow/contrib/tensor_forest/proto -tensorflow/contrib/tensorboard/graph_explorer/proto tensorflow/contrib/tensorboard/plugins/projector tensorflow/contrib/tensorboard/plugins/trace tensorflow/contrib/tpu/proto diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index c6a15f2ca075c8de96786a580c7ddb89541df5bc..2e0a2fcef4cbdc50f0521296c4a25a864dbd8b77 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -21,9 +21,8 @@ set(tf_c_srcs "${tensorflow_source_dir}/tensorflow/c/c_api_function.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.h" + "${tensorflow_source_dir}/tensorflow/c/eager/c_api_debug.cc" "${tensorflow_source_dir}/tensorflow/c/eager/tape.h" - "${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc" - "${tensorflow_source_dir}/tensorflow/c/eager/runtime.h" "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.cc" "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.h" "${tensorflow_source_dir}/tensorflow/c/tf_status_helper.cc" @@ -38,13 +37,15 @@ add_dependencies( tf_core_lib tf_protos_cc) -add_library(tf_c_python_api OBJECT - "${tensorflow_source_dir}/tensorflow/c/python_api.cc" - "${tensorflow_source_dir}/tensorflow/c/python_api.h" -) -add_dependencies( - tf_c_python_api - tf_c - 
tf_core_lib - tf_core_framework - tf_protos_cc) +if(tensorflow_BUILD_PYTHON_BINDINGS) + add_library(tf_c_python_api OBJECT + "${tensorflow_source_dir}/tensorflow/c/python_api.cc" + "${tensorflow_source_dir}/tensorflow/c/python_api.h" + ) + add_dependencies( + tf_c_python_api + tf_c + tf_core_lib + tf_core_framework + tf_protos_cc) +endif() diff --git a/tensorflow/contrib/cmake/tf_cc_ops.cmake b/tensorflow/contrib/cmake/tf_cc_ops.cmake index f73da0b8ab18af1eca4c2bd577604595f8b8ec6d..6c90cf398c69c8c1b22ea75e0c407f258e2535f9 100644 --- a/tensorflow/contrib/cmake/tf_cc_ops.cmake +++ b/tensorflow/contrib/cmake/tf_cc_ops.cmake @@ -155,7 +155,7 @@ if (WIN32) set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.lib") endif() else (WIN32) - set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal.so") + set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal${CMAKE_SHARED_LIBRARY_SUFFIX}") endif (WIN32) add_custom_target(tf_extension_ops) diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index b47c32f1c48b3d42fe5b4ba115cc2a511b7ee5f4..dac84ccb0dbf4848329e35a6e9bcf6213d8c0e55 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -213,10 +213,6 @@ else() list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_srcs_exclude}) endif() -file(GLOB tf_core_platform_exclude_srcs - "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.cc") -list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_exclude_srcs}) - list(APPEND tf_core_lib_srcs ${tf_core_platform_srcs}) if(UNIX) @@ -286,8 +282,6 @@ set(tf_version_srcs ${tensorflow_source_dir}/tensorflow/core/util/version_info.c file(GLOB_RECURSE tf_core_framework_srcs "${tensorflow_source_dir}/tensorflow/core/framework/*.h" "${tensorflow_source_dir}/tensorflow/core/framework/*.cc" - 
"${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.h" - "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.cc" "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.h" "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc" "${tensorflow_source_dir}/tensorflow/core/graph/graph.h" diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 1505d3e2083b5a3446a7f85d59c73816e65e1a2a..2d76bf530a2100b2afa80a16a5d64b6ec51ffc68 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -68,6 +68,7 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/csv_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index e558691de4b74988031f7b2204aad92e8c7af68b..bc753333dba4f67eee0114c4022743dd59a05982 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -113,6 +113,7 @@ GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_stats "${tensorflow_source_dir}/tensor GENERATE_CONTRIB_OP_LIBRARY(text_skip_gram "${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(tpu "${tpu_ops_srcs}") GENERATE_CONTRIB_OP_LIBRARY(bigquery_reader "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(gcs_config 
"${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/gcs_config_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(reduce_slice_ops "${tensorflow_source_dir}/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc") ######################################################## diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index c4bdb69d828b269e6246777e74c3756ba1c4b96f..92446044892127284ecb8753a250b77cb2a5743a 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -244,13 +244,11 @@ add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD # tf_python_op_gen_main library ######################################################## set(tf_python_op_gen_main_srcs - "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.h" - "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" - "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" - "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.h" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc" ) add_library(tf_python_op_gen_main OBJECT ${tf_python_op_gen_main_srcs}) @@ -422,6 +420,8 @@ GENERATE_PYTHON_OP_LIB("contrib_text_skip_gram_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/text/python/ops/gen_skip_gram_ops.py) GENERATE_PYTHON_OP_LIB("contrib_bigquery_reader_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_bigquery_reader_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_gcs_config_ops" + DESTINATION 
${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_gcs_config_ops.py) GENERATE_PYTHON_OP_LIB("stateless_random_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/stateless/gen_stateless_random_ops.py) GENERATE_PYTHON_OP_LIB("debug_ops" @@ -464,12 +464,12 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe_src.cc" "${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.h" "${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.cc" - "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.h" - "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.cc" "${tensorflow_source_dir}/tensorflow/python/framework/cpp_shape_inference.h" "${tensorflow_source_dir}/tensorflow/python/framework/cpp_shape_inference.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.h" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/numpy.h" @@ -715,7 +715,7 @@ if(WIN32) endif() else() add_custom_command(TARGET pywrap_tensorflow_internal POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal.so + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal${CMAKE_SHARED_LIBRARY_SUFFIX} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.so) endif() @@ -725,7 +725,7 @@ endif() ######################################################## # Parse tensorflow/tools/api/generator/BUILD to get list of generated files. 
-FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/BUILD api_generator_BUILD_text) +FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text) STRING(REGEX MATCH "# BEGIN GENERATED FILES.*# END GENERATED FILES" api_init_files_text ${api_generator_BUILD_text}) string(REPLACE "# BEGIN GENERATED FILES" "" api_init_files_text ${api_init_files_text}) string(REPLACE "# END GENERATED FILES" "" api_init_files_text ${api_init_files_text}) @@ -736,7 +736,7 @@ foreach(api_init_file ${api_init_files_list}) string(STRIP "${api_init_file}" api_init_file) if(api_init_file) string(REPLACE "\"" "" api_init_file "${api_init_file}") # Remove quotes - list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/${api_init_file}") + list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/${api_init_file}") endif() endforeach(api_init_file) set(api_init_list_file "${tensorflow_source_dir}/api_init_files_list.txt") @@ -749,18 +749,16 @@ add_custom_command( # tensorflow/__init__.py depends on files generated in this step. So, remove it while # this step is running since the files aren't there yet. - COMMAND ${CMAKE_COMMAND} -E rename ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/final.__init__.py - COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py # Run create_python_api.py to generate API init files. COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" "${api_init_list_file}" - - # Re-add tensorflow/__init__.py back. 
- COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py - COMMAND ${CMAKE_COMMAND} -E rename ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/final.__init__.py - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" + "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" + "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" + "--package=tensorflow.python" + "--apiname=tensorflow" + "${api_init_list_file}" COMMENT "Generating __init__.py files for Python API." WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" @@ -769,7 +767,49 @@ add_custom_command( add_custom_target(tf_python_api SOURCES ${api_init_files}) add_dependencies(tf_python_api tf_python_ops) +# TODO(mikecase): This can be removed once tf.estimator is moved +# out of TensorFlow. +######################################################## +# Generate API __init__.py files for tf.estimator. +######################################################## +# Parse tensorflow/tools/api/generator/api_gen.bzl to get the list of generated files.
+FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text) +STRING(REGEX MATCH "# BEGIN GENERATED ESTIMATOR FILES.*# END GENERATED ESTIMATOR FILES" api_init_files_text ${api_generator_BUILD_text}) +string(REPLACE "# BEGIN GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) +string(REPLACE "# END GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) +string(REPLACE "," ";" api_init_files_list ${api_init_files_text}) + +set(api_init_files "") +foreach(api_init_file ${api_init_files_list}) + string(STRIP "${api_init_file}" api_init_file) + if(api_init_file) + string(REPLACE "\"" "" api_init_file "${api_init_file}") # Remove quotes + list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api/${api_init_file}") + endif() +endforeach(api_init_file) +set(estimator_api_init_list_file "${tensorflow_source_dir}/estimator_api_init_files_list.txt") +file(WRITE "${estimator_api_init_list_file}" "${api_init_files}") + +# Run create_python_api.py to generate __init__.py files. +add_custom_command( + OUTPUT ${api_init_files} + DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops + + # Run create_python_api.py to generate API init files. + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" + "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api" + "--package=tensorflow.python.estimator" + "--apiname=estimator" + "${estimator_api_init_list_file}" + + COMMENT "Generating __init__.py files for Python API." 
+ WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" +) + +add_custom_target(estimator_python_api SOURCES ${api_init_files}) +add_dependencies(estimator_python_api tf_python_ops) ############################################################ # Build a PIP package containing the TensorFlow runtime. ############################################################ @@ -780,6 +820,7 @@ add_dependencies(tf_python_build_pip_package tf_python_touchup_modules tf_python_ops tf_python_api + estimator_python_api tf_extension_ops) # Fix-up Python files that were not included by the add_python_module() macros. @@ -791,7 +832,6 @@ add_custom_command(TARGET tf_python_build_pip_package POST_BUILD add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/contrib/testing/python/framework/util_test.py ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/testing/python/framework/) - add_custom_command(TARGET tf_python_build_pip_package POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/tools/pip_package/README ${CMAKE_CURRENT_BINARY_DIR}/tf_python/) diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 5942ff3363a96de70df7e13d0857e4ad82e35fee..eb9482dc25f2be8ce46cc38bf3dd28889b09a9d4 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -212,6 +212,10 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/gmm_test.py" # Disable following manual tag in BUILD. 
"${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py" + # These tests depend on a .so file + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/duplicate_op_test.py + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/invalid_op_test.py + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/ackermann_test.py ) if (WIN32) diff --git a/tensorflow/contrib/cmake/tools/create_def_file.py b/tensorflow/contrib/cmake/tools/create_def_file.py index cffe069aa352f8a6f2c436bc70b62f54e2336ac6..4f957f1e0b46fde5daacbc59657af994e13c42d5 100644 --- a/tensorflow/contrib/cmake/tools/create_def_file.py +++ b/tensorflow/contrib/cmake/tools/create_def_file.py @@ -44,7 +44,8 @@ UNDNAME = "undname.exe" DUMPBIN = "dumpbin.exe" # Exclude if matched -EXCLUDE_RE = re.compile(r"RTTI|deleting destructor|::internal::") +EXCLUDE_RE = re.compile(r"RTTI|deleting destructor|::internal::|Internal|" + r"python_op_gen_internal|grappler") # Include if matched before exclude INCLUDEPRE_RE = re.compile(r"google::protobuf::internal::ExplicitlyConstructed|" @@ -56,6 +57,10 @@ INCLUDEPRE_RE = re.compile(r"google::protobuf::internal::ExplicitlyConstructed|" r"tensorflow::ops::internal::Enter|" r"tensorflow::strings::internal::AppendPieces|" r"tensorflow::strings::internal::CatPieces|" + r"tensorflow::errors::Internal|" + r"tensorflow::Tensor::CopyFromInternal|" + r"tensorflow::kernel_factory::" + r"OpKernelRegistrar::InitInternal|" r"tensorflow::io::internal::JoinPathImpl") # Include if matched after exclude @@ -64,7 +69,7 @@ INCLUDE_RE = re.compile(r"^(TF_\w*)$|" r"tensorflow::|" r"functor::|" r"\?nsync_|" - r"perftools::gputools") + r"stream_executor::") # We want to identify data members explicitly in the DEF file, so that no one # can implicitly link against the DLL if they use one of the variables exported diff --git a/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc b/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc index 
ae4d9d2836a0f89a9765004a85bc3c292b0e484f..81b36ca902b82220d9c5282a1ec72324a6d95922 100644 --- a/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc +++ b/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op.h" diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck.py b/tensorflow/contrib/coder/python/layers/entropybottleneck.py index f039cb0f5265b920200f63c5bd5ebeb4e23826be..0fbe3081af0b4de7f116918b3f49efe91a2d83bd 100644 --- a/tensorflow/contrib/coder/python/layers/entropybottleneck.py +++ b/tensorflow/contrib/coder/python/layers/entropybottleneck.py @@ -28,7 +28,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras._impl.keras import engine +from tensorflow.python.keras import engine from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import init_ops diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py index b2f678fb29cedd3ec32f0460354cc4ac18fb63d3..a56a01b16356e12b83344474c7fbe427530f0c74 100644 --- a/tensorflow/contrib/compiler/jit_test.py +++ b/tensorflow/contrib/compiler/jit_test.py @@ -24,7 +24,6 @@ from tensorflow.python.framework import function from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed -from tensorflow.python.framework import 
test_util from tensorflow.python.ops import gradients from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -170,7 +169,6 @@ class JITTest(test.TestCase): self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s) -@test_util.with_c_api class CompilationEnabledInGradientTest(test.TestCase): def testCompilationInGradient(self): diff --git a/tensorflow/contrib/control_flow/BUILD b/tensorflow/contrib/control_flow/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..e8036d63aeeac224b226899c036416a06b4ffe65 --- /dev/null +++ b/tensorflow/contrib/control_flow/BUILD @@ -0,0 +1,53 @@ +# New implementations of control flow ops + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +py_library( + name = "control_flow", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":cond_v2", + ], +) + +py_library( + name = "cond_v2", + srcs = ["python/cond_v2.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:c_api_util", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + "//tensorflow/python:function_def_to_graph", + "//tensorflow/python:functional_ops_gen", + "//tensorflow/python:gradients", + "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:util", + ], +) + +tf_py_test( + name = "cond_v2_test", + size = "small", + srcs = ["python/cond_v2_test.py"], + additional_deps = [ + ":cond_v2", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:training", + ], + grpc_enabled = True, +) diff --git 
a/tensorflow/python/keras/_impl/keras/wrappers/__init__.py b/tensorflow/contrib/control_flow/__init__.py similarity index 68% rename from tensorflow/python/keras/_impl/keras/wrappers/__init__.py rename to tensorflow/contrib/control_flow/__init__.py index 20c95929e3d2e1f66e66efe43b9685c5d6ed1c10..582af2cf10a3d92dd8611b0f2826625e3acfb099 100644 --- a/tensorflow/python/keras/_impl/keras/wrappers/__init__.py +++ b/tensorflow/contrib/control_flow/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,11 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Keras API wrappers. + +"""New implementations of TF control flow ops. + +@@cond_v2 """ + from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.wrappers import scikit_learn +# pylint: disable=unused-import +from tensorflow.contrib.control_flow.python.cond_v2 import cond_v2 +# pylint: enable=unused-import + +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__) diff --git a/tensorflow/contrib/control_flow/python/cond_v2.py b/tensorflow/contrib/control_flow/python/cond_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..9ffad9caa92d2d3be8f598758a443b0eceb8d4d8 --- /dev/null +++ b/tensorflow/contrib/control_flow/python/cond_v2.py @@ -0,0 +1,429 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""cond_v2 and gradient. + +This is a version of cond that emits a single If op, as well as the gradient +function for If ops produced by cond_v2. This will eventually replace the +current tf.cond implementation once it reaches feature and performance parity. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python import pywrap_tensorflow as c_api +from tensorflow.python.framework import c_api_util +from tensorflow.python.framework import function +from tensorflow.python.framework import function_def_to_graph +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_functional_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.util import compat + + +# NOTE(skyewm): TensorFlow uses protected class methods and fields to signify +# that they aren't part of the official public API. These protected members +# often need to be used by implementation code however. Rather than litter the +# code with pylint comments, we ignore protected access violations for +# readability. 
+# pylint: disable=protected-access + + +def cond_v2(pred, true_fn, false_fn, name="cond"): + """Like tf.cond, except emits a single If op.""" + with ops.name_scope(name) as scope: + true_graph = function.func_graph_from_py_func(true_fn, [], [], + name="%s_true" % scope) + false_graph = function.func_graph_from_py_func(false_fn, [], [], + name="%s_false" % scope) + _check_same_outputs(true_graph, false_graph) + + # Add inputs to true_graph and false_graph to make them match. Note that + # this modifies true_graph and false_graph. + cond_inputs = _make_inputs_match(true_graph, false_graph, + true_graph.extra_inputs, + false_graph.extra_inputs) + + # Add all intermediate tensors as function outputs so they're available for + # the gradient computation. + + true_intermediates = _get_intermediates(true_graph) + false_intermediates = _get_intermediates(false_graph) + + # Save the original number of outputs to return to the caller. + num_cond_outputs = len(true_graph.outputs) + + # Make the number/type of new intermediate outputs match. + extra_true_outputs, extra_false_outputs = _pad_params( + true_graph, false_graph, true_intermediates, false_intermediates) + + true_graph.outputs.extend(extra_true_outputs) + false_graph.outputs.extend(extra_false_outputs) + + # Create the If op. + tensors = gen_functional_ops._if( + pred, cond_inputs, [t.dtype for t in true_graph.outputs], + _create_new_tf_function(true_graph), + _create_new_tf_function(false_graph), + name=scope) + + return tensors[:num_cond_outputs] + + +@ops.RegisterGradient("If") +def _IfGrad(op, *grads): # pylint: disable=invalid-name + """The gradient of an If op produced by cond_v2.""" + true_graph, false_graph = _get_func_graphs(op) + + # Create grad functions that compute the gradient of the true/false forward + # graphs. These functions will capture tensors from the forward pass + # functions. 
+ true_grad_graph = _create_grad_func( + true_graph, grads, _get_grad_fn_name(true_graph)) + false_grad_graph = _create_grad_func( + false_graph, grads, _get_grad_fn_name(false_graph)) + + assert ([t.dtype for t in true_grad_graph.outputs] == + [t.dtype for t in false_grad_graph.outputs]) + + # Match up the captured grad function inputs with outputs of 'op' and other + # external tensors. + true_grad_inputs = _get_grad_inputs(op, true_graph, true_grad_graph) + false_grad_inputs = _get_grad_inputs(op, false_graph, false_grad_graph) + + # Make the inputs to true_grad_graph and false_grad_graph match. Note that + # this modifies true_grad_graph and false_grad_graph. + grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph, + true_grad_inputs, false_grad_inputs) + + # Add all intermediate tensors as function outputs so they're available for + # higher-order gradient computations. + + true_grad_intermediates = _get_intermediates(true_grad_graph) + false_grad_intermediates = _get_intermediates(false_grad_graph) + + # Save the original number of gradient outputs to return. + num_grad_outputs = len(true_grad_graph.outputs) + + # Make the number/type of new intermediate outputs match. + extra_true_grad_outputs, extra_false_grad_outputs = _pad_params( + true_grad_graph, false_grad_graph, + true_grad_intermediates, false_grad_intermediates) + + true_grad_graph.outputs.extend(extra_true_grad_outputs) + false_grad_graph.outputs.extend(extra_false_grad_outputs) + + # Create the gradient If op. + tensors = gen_functional_ops._if( + op.inputs[0], grad_inputs, [t.dtype for t in true_grad_graph.outputs], + _create_new_tf_function(true_grad_graph), + _create_new_tf_function(false_grad_graph)) + + # The predicate has no gradient. + return [None] + tensors[:num_grad_outputs] + + +def _get_func_graphs(if_op): + """Returns `_FuncGraph`s for the input op branches. + + Args: + if_op: The _If Operation. 
+ + Returns: + A 2-tuple of the `_FuncGraph`s of the then_branch and else_branch. + """ + def _get_func_graph_for_branch(branch_name): + extra_inputs = if_op.inputs[1:] # First input is pred. + input_shapes = [t.shape for t in extra_inputs] + func_name = if_op.get_attr(branch_name).name + fdef = if_op.graph._get_function(func_name).definition + func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes) + func_graph.extra_inputs = extra_inputs + func_graph.extra_args = func_graph.inputs + func_graph._captured = dict(zip(extra_inputs, func_graph.inputs)) + return func_graph + + return (_get_func_graph_for_branch("then_branch"), + _get_func_graph_for_branch("else_branch")) + + +def _grad_fn(func_graph, grads): + """The gradient function for each conditional branch. + + This function builds the gradient graph of the corresponding forward-pass + conditional branch in `func_graph`. This is done by differentiating + func_graph's outputs w.r.t. its inputs. + + Args: + func_graph: function._FuncGraph. The corresponding forward-pass function. + grads: The list of input gradient Tensors. + + Returns: + The output gradient Tensors. + """ + # Filter out untrainable function outputs. + # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes + # cause _GradientsHelper to raise an exception (e.g. the implementation + # doesn't expect 'ys' to contain boolean tensors). + assert len(func_graph.outputs) == len(grads) + ys = [] + grad_ys = [] + for y, grad_y in zip(func_graph.outputs, grads): + if not gradients_impl._IsTrainable(y): + continue + ys.append(y) + grad_ys.append(grad_y) + + # Build the gradient graph. Note that this builds the gradient computation of + # func_graph in the current graph, which requires capturing tensors from + # func_graph. The captured func_graph tensors are resolved to external tensors + # in _get_grad_inputs. 
+ result = gradients_impl._GradientsHelper( + ys, func_graph.inputs, grad_ys=grad_ys, + src_graph=func_graph) + + # Functions can't return None; replace Nones with zero tensors. + # TODO(b/80444525): don't return anything here and make _IfGrad return None if + # both branches have zero gradient. + for i in range(len(result)): + if result[i] is None: + result[i] = array_ops.zeros_like(func_graph.inputs[i]) + + return result + + +def _create_grad_func(func_graph, grads, name): + """Returns the _FuncGraph representation of _grad_fn.""" + return function.func_graph_from_py_func(lambda: _grad_fn(func_graph, grads), + [], [], name) + + +def _get_grad_inputs(if_op, cond_graph, grad_graph): + """Returns the tensors we should pass to grad_graph. + + This method handles tensors captured from cond_graph in grad_graph. It + converts these to suitable input tensors from the outer graph. + + Args: + if_op: Operation. The forward-pass If op that uses cond_graph. + cond_graph: function._FuncGraph. The forward-pass function. + grad_graph: function._FuncGraph. The gradients function. + + Returns: + A list of input tensors to be passed to grad_graph. + """ + inputs = [] + + # Maps placeholders in cond_graph -> input tensor in outer graph. + forward_input_map = {v: k for k, v in cond_graph._captured.items()} + + for t in grad_graph.extra_inputs: + if t.graph == ops.get_default_graph(): + # t is in the outer graph (e.g. one of the input gradients). + inputs.append(t) + elif t in forward_input_map: + # t is an input placeholder in cond_graph. Get the corresponding input + # tensor in the outer graph. + assert t.graph == cond_graph + assert forward_input_map[t].graph == ops.get_default_graph() + inputs.append(forward_input_map[t]) + else: + # t is an intermediate value in cond_graph. Get the corresponding output + # of 'if_op' (note that all intermediate values are outputs).
+ assert t.graph == cond_graph + output_idx = cond_graph.outputs.index(t) + inputs.append(if_op.outputs[output_idx]) + + return inputs + + +def _create_new_tf_function(func_graph): + """Converts func_graph to a TF_Function and adds it to the current graph. + + Args: + func_graph: function._FuncGraph + + Returns: + The name of the new TF_Function. + """ + c_func = c_api.TF_GraphToFunction_wrapper( + func_graph._c_graph, + compat.as_str(func_graph.name), + False, # append_hash_to_fn_name + None, # opers + [t._as_tf_output() for t in func_graph.inputs], + [t._as_tf_output() for t in func_graph.outputs], + [], + None, # opts + None) # description + _ = c_api_util.ScopedTFFunction(c_func) + + # TODO(b/109833212): this sucks, we're serializing the TF_Function*, + # deserializing it into a Python FunctionDef, then reserializing it to create + # a new TF_Function that we add to the graph. + fdef = function.function_def_from_tf_function(c_func) + defined_func = function._from_definition(fdef) + defined_func.add_to_graph(ops.get_default_graph()) + + return func_graph.name + + +def _get_intermediates(func_graph): + """Returns all tensors in `func_graph` that aren't inputs or outputs.""" + intermediates = [] + for op in func_graph.get_operations(): + for t in op.outputs: + if t in func_graph.inputs: continue + if t in func_graph.outputs: continue + intermediates.append(t) + return intermediates + + +def _separate_unique_inputs(true_inputs, false_inputs): + """Separates tensors appearing only in true_inputs or false_inputs, or both. + + Args: + true_inputs: list of Tensors + false_inputs: list of Tensors + + Returns: + Three lists of Tensors: + 1. The tensors that appear in both true_inputs and false_inputs + 2. The tensors that only appear in true_inputs + 3. 
The tensors that only appear in false_inputs + """ + true_inputs = set(true_inputs) + false_inputs = set(false_inputs) + + shared_inputs = true_inputs.intersection(false_inputs) + true_only_inputs = true_inputs - false_inputs + false_only_inputs = false_inputs - true_inputs + + return list(shared_inputs), list(true_only_inputs), list(false_only_inputs) + + +def _pad_params(true_graph, false_graph, true_params, false_params): + """Returns new param lists that have matching signatures. + + This is done by mirroring each param list in the other using dummy params. + There is no merging of params. + + Args: + true_graph: function._FuncGraph + false_graph: function._FuncGraph + true_params: a list of Tensors from true_graph + false_params: a list of Tensors from false_graph + + Returns: + A new list of Tensors in true_graph and a new list of Tensors in + false_graph. The two lists have the same number of Tensors, with matching + types and shapes across the lists. + """ + new_true_params = (true_params + + _create_dummy_params(true_graph, false_params)) + new_false_inputs = (_create_dummy_params(false_graph, true_params) + + false_params) + return new_true_params, new_false_inputs + + +def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs): + """Modifies true_graph and false_graph so they have the same input signature. + + This method reorders and/or adds parameters to true_graph and false_graph so + they have the same input signature, and updates the 'inputs', 'extra_inputs', + and '_captured' fields of both graphs accordingly. It uses the input tensors + from the outer graph to avoid duplicating shared arguments. + + Args: + true_graph: function._FuncGraph + false_graph: function._FuncGraph + true_inputs: a list of Tensors in the outer graph. The inputs for + true_graph. + false_inputs: a list of Tensors in the outer graph. The inputs for + false_graph. 
+ + Returns: + A new list of Tensors from the outer graph that are the new inputs for both + true_graph and false_graph. This is a deduped version of true_inputs + + false_inputs. + """ + shared_inputs, true_only_inputs, false_only_inputs = _separate_unique_inputs( + true_inputs, false_inputs) + + new_inputs = shared_inputs + true_only_inputs + false_only_inputs + + true_input_to_param = dict(zip(true_inputs, true_graph.inputs)) + false_input_to_param = dict(zip(false_inputs, false_graph.inputs)) + + true_graph.inputs = ( + [true_input_to_param[t] for t in shared_inputs] + + [true_input_to_param[t] for t in true_only_inputs] + + _create_dummy_params(true_graph, false_only_inputs)) + + false_graph.inputs = ( + [false_input_to_param[t] for t in shared_inputs] + + _create_dummy_params(false_graph, true_only_inputs) + + [false_input_to_param[t] for t in false_only_inputs]) + + # Rewrite the _FuncGraphs' state to reflect the new inputs. + true_graph.extra_inputs = new_inputs + false_graph.extra_inputs = new_inputs + + true_graph._captured = dict(zip(new_inputs, true_graph.inputs)) + false_graph._captured = dict(zip(new_inputs, false_graph.inputs)) + + return new_inputs + + +def _create_dummy_params(func_graph, template_tensors): + """Creates tensors in func_graph to represent template_tensors. + + Args: + func_graph: function._FuncGraph. + template_tensors: a list of tensors in the outer graph. + + Returns: + A list of tensors in func_graph. 
+ """ + with func_graph.as_default(): + return [gen_functional_ops.fake_param(dtype=t.dtype, shape=t.shape) + for t in template_tensors] + + +def _get_grad_fn_name(func_graph): + """Returns a unique name to use for the grad function of `func_graph`.""" + name = "%s_grad" % func_graph.name + + base_name = name + counter = 1 + while ops.get_default_graph()._is_function(name): + name = "%s_%s" % (base_name, counter) + counter += 1 + + return name + + +def _check_same_outputs(true_graph, false_graph): + """Raises an error if true_graph and false_graph have different outputs.""" + true_output_types = [t.dtype for t in true_graph.outputs] + false_output_types = [t.dtype for t in false_graph.outputs] + if (len(true_graph.outputs) != len(false_graph.outputs) or + true_output_types != false_output_types): + raise ValueError( + "true_fn() and false_fn() must return the same number and type of " + "arguments, got:\n" + " true_fn: %s\n" + " false_fn: %s" % (true_output_types, false_output_types)) diff --git a/tensorflow/contrib/control_flow/python/cond_v2_test.py b/tensorflow/contrib/control_flow/python/cond_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..338601aa2c5ee8ffc97aa2e07ff05a2d17531936 --- /dev/null +++ b/tensorflow/contrib/control_flow/python/cond_v2_test.py @@ -0,0 +1,171 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== + +"""Tests for cond_v2.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.control_flow.python import cond_v2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test +from tensorflow.python.training import saver + + +class NewCondTest(test.TestCase): + + def _testCond(self, true_fn, false_fn, train_vals): + pred = array_ops.placeholder(dtypes.bool, name="pred") + + expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected") + actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual") + + expected_grad = gradients_impl.gradients(expected, train_vals) + actual_grad = gradients_impl.gradients(actual, train_vals) + + with self.test_session() as sess: + expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run( + (expected, actual, expected_grad, actual_grad), {pred: True}) + self.assertEqual(expected_val, actual_val) + self.assertEqual(expected_grad_val, actual_grad_val) + + expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run( + (expected, actual, expected_grad, actual_grad), {pred: False}) + self.assertEqual(expected_val, actual_val) + self.assertEqual(expected_grad_val, actual_grad_val) + + def testBasic(self): + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + return x * 2.0 + + def false_fn(): + return y * 3.0 + + self._testCond(true_fn, false_fn, [x]) + self._testCond(true_fn, false_fn, [x, y]) + self._testCond(true_fn, false_fn, [y]) + + def testBasic2(self): + x 
= constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + return x * y * 2.0 + + def false_fn(): + return 2.0 + + self._testCond(true_fn, false_fn, [x]) + self._testCond(true_fn, false_fn, [x, y]) + self._testCond(true_fn, false_fn, [y]) + + def testNoInputs(self): + pred = array_ops.placeholder(dtypes.bool, name="pred") + + def true_fn(): + return constant_op.constant(1.0) + + def false_fn(): + return constant_op.constant(2.0) + + out = cond_v2.cond_v2(pred, true_fn, false_fn) + + with self.test_session() as sess: + self.assertEqual(sess.run(out, {pred: True}), [1.0]) + self.assertEqual(sess.run(out, {pred: False}), [2.0]) + + def testSecondDerivative(self): + pred = array_ops.placeholder(dtypes.bool, name="pred") + x = constant_op.constant(3.0, name="x") + + def true_fn(): + return math_ops.pow(x, 3) + + def false_fn(): + return x + + cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond") + cond_grad = gradients_impl.gradients(cond, [x]) + cond_grad_grad = gradients_impl.gradients(cond_grad, [x]) + + with self.test_session() as sess: + # d[x^3]/dx = 3x^2 + true_val = sess.run(cond_grad, {pred: True}) + self.assertEqual(true_val, [27.0]) + # d[x]/dx = 1 + false_val = sess.run(cond_grad, {pred: False}) + self.assertEqual(false_val, [1.0]) + + true_val = sess.run(cond_grad_grad, {pred: True}) + # d2[x^3]/dx2 = 6x + self.assertEqual(true_val, [18.0]) + false_val = sess.run(cond_grad_grad, {pred: False}) + # d2[x]/dx2 = 0 + self.assertEqual(false_val, [0.0]) + + def testGradientOfDeserializedCond(self): + with ops.Graph().as_default(): + pred = array_ops.placeholder(dtypes.bool, name="pred") + x = constant_op.constant(3.0, name="x") + ops.add_to_collection("x", x) + + def true_fn(): + return math_ops.pow(x, 3) + + def false_fn(): + return x + + ops.add_to_collection("pred", pred) + cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond") + for c in cond: + ops.add_to_collection("cond", c) + meta_graph = 
saver.export_meta_graph() + + with ops.Graph().as_default() as g: + saver.import_meta_graph(meta_graph) + x = ops.get_collection("x")[0] + pred = ops.get_collection("pred")[0] + cond = ops.get_collection("cond") + cond_grad = gradients_impl.gradients(cond, [x], name="cond_grad") + cond_grad_grad = gradients_impl.gradients( + cond_grad, [x], name="cond_grad_grad") + with self.test_session(graph=g) as sess: + # d[x^3]/dx = 3x^2 + true_val = sess.run(cond_grad, {pred: True}) + self.assertEqual(true_val, [27.0]) + # d[x]/dx = 1 + false_val = sess.run(cond_grad, {pred: False}) + self.assertEqual(false_val, [1.0]) + + true_val = sess.run(cond_grad_grad, {pred: True}) + # d2[x^3]/dx2 = 6x + self.assertEqual(true_val, [18.0]) + false_val = sess.run(cond_grad_grad, {pred: False}) + # d2[x]/dx2 = 0 + self.assertEqual(false_val, [0.0]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index 102bc460fdadb0ad5dc9a2960b8655c55357108e..a0dd3881a86c19e47ccb65f84a2477a55626b81c 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -218,7 +218,6 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''): new_control_inputs, input_types, new_original_op, op_def) #Use Graph's hidden methods to add the op - to_graph._add_op(new_op) # pylint: disable=protected-access to_graph._record_op_seen_by_control_dependencies(new_op) for device_function in reversed(to_graph._device_function_stack): new_op._set_device(device_function(new_op)) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index 33ddfb8dee1c446f22c7d0071f9a0e2bbac6bdad..8285ea04926d3a24e9c22bd6d69eb7a48f5e3a85 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ 
b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -54,11 +54,11 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import adagrad from tensorflow.python.training import adam -from tensorflow.python.training import checkpointable_utils from tensorflow.python.training import gradient_descent from tensorflow.python.training import momentum from tensorflow.python.training import rmsprop from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training.checkpointable import util as checkpointable_utils CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 73a961992e19fabec5d0f75be1b52dbba20eb7af..8822a7523f6b168f569e29970c9c29f2eb3614fc 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -20,11 +20,10 @@ from __future__ import print_function import os from tensorflow.contrib.checkpoint.python import split_dependency from tensorflow.contrib.rnn.python.ops import lstm_ops -from tensorflow.python.framework import common_shapes from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed -from tensorflow.python.keras._impl.keras.engine import base_layer +from tensorflow.python.keras.engine import base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_cudnn_rnn_ops from tensorflow.python.ops import init_ops @@ -33,8 +32,8 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.training import checkpointable as checkpointable_lib from tensorflow.python.training import 
saver +from tensorflow.python.training.checkpointable import base as checkpointable_lib CUDNN_RNN_UNIDIRECTION = "unidirectional" CUDNN_RNN_BIDIRECTION = "bidirectional" @@ -1647,10 +1646,3 @@ class CudnnRNNRelu(_CudnnRNNNoInputC): # 1 set of weight and bias parameters for the recurrent input, and 1 for the # previous layer input. _NUM_PARAMS_PER_LAYER = CUDNN_RNN_RELU_PARAMS_PER_LAYER - - -ops.RegisterShape("CudnnRNNParamsSize")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNNParamsToCanonical")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNNCanonicalToParams")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNN")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNNBackprop")(common_shapes.call_cpp_shape_fn) diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 077cbba9d2ae41a83f6c358a63ae27aec5741e2c..1af1ed08b53ee04367eb316d5c9caa0216f2e88d 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -23,11 +23,14 @@ removing existing functionality. See the @{$datasets$Importing Data} Programmer's Guide for an overview. 
@@Counter +@@CheckpointInputPipelineHook +@@CsvDataset @@SqlDataset @@assert_element_shape @@batch_and_drop_remainder @@bucket_by_sequence_length +@@choose_from_datasets @@dense_to_sparse_batch @@enumerate_dataset @@group_by_window @@ -72,8 +75,10 @@ from tensorflow.contrib.data.python.ops.grouping import group_by_window from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datasets from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave +from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device +from tensorflow.contrib.data.python.ops.readers import CsvDataset from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset from tensorflow.contrib.data.python.ops.readers import make_csv_dataset from tensorflow.contrib.data.python.ops.readers import read_batch_features diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD index c56910c7833d4c54fa8db27cd061b404013f3f54..7b69e10441eba3e38c979d5715c16699ac2710ed 100644 --- a/tensorflow/contrib/data/kernels/BUILD +++ b/tensorflow/contrib/data/kernels/BUILD @@ -29,6 +29,16 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "csv_dataset_op", + srcs = ["csv_dataset_op.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], +) + cc_library( name = "ignore_errors_dataset_op", srcs = ["ignore_errors_dataset_op.cc"], @@ -63,6 +73,7 @@ cc_library( cc_library( name = "dataset_kernels", deps = [ + ":csv_dataset_op", ":directed_interleave_dataset_op", ":ignore_errors_dataset_op", ":prefetching_kernels", diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc 
b/tensorflow/contrib/data/kernels/csv_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..4657807785d58727d34f37172bd30c56a5b7cde6 --- /dev/null +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -0,0 +1,734 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/parsing_ops.cc. +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/io/random_inputstream.h" + +namespace tensorflow { +namespace { + +class CSVDatasetOp : public DatasetOpKernel { + public: + explicit CSVDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + const Tensor* filenames_tensor; + OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); + OP_REQUIRES( + ctx, filenames_tensor->dims() <= 1, + errors::InvalidArgument("`filenames` must be a scalar or a vector.")); + + OpInputList record_defaults_list; + OP_REQUIRES_OK(ctx, + ctx->input_list("record_defaults", &record_defaults_list)); + for (int i = 0; i < 
record_defaults_list.size(); ++i) { + OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2, + errors::InvalidArgument( + "There should only be 1 default per field but field ", i, + " has ", record_defaults_list[i].NumElements())); + } + + const Tensor* select_cols_tensor; + OP_REQUIRES_OK(ctx, ctx->input("select_cols", &select_cols_tensor)); + OP_REQUIRES(ctx, select_cols_tensor->dims() == 1, + errors::InvalidArgument("`select_cols` must be a vector.")); + + int64 buffer_size; + OP_REQUIRES_OK( + ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size)); + OP_REQUIRES(ctx, buffer_size > 0, + errors::InvalidArgument("buffer_size should be positive")); + + string delim; + OP_REQUIRES_OK(ctx, + ParseScalarArgument<string>(ctx, "field_delim", &delim)); + OP_REQUIRES(ctx, delim.size() == 1, + errors::InvalidArgument("field_delim should be only 1 char")); + + bool header; + OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "header", &header)); + + bool use_quote_delim; + OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "use_quote_delim", + &use_quote_delim)); + string na_value; + OP_REQUIRES_OK(ctx, + ParseScalarArgument<string>(ctx, "na_value", &na_value)); + + std::vector<Tensor> record_defaults; + record_defaults.reserve(record_defaults_list.size()); + for (const Tensor& t : record_defaults_list) { + record_defaults.push_back(t); + } + + std::vector<string> filenames; + filenames.reserve(filenames_tensor->NumElements()); + for (int i = 0; i < filenames_tensor->NumElements(); ++i) { + filenames.push_back(filenames_tensor->flat<string>()(i)); + } + + std::vector<int64> select_cols; + select_cols.reserve(select_cols_tensor->NumElements()); + for (int i = 0; i < select_cols_tensor->NumElements(); ++i) { + select_cols.push_back(select_cols_tensor->flat<int64>()(i)); + } + OP_REQUIRES( + ctx, output_types_.size() == select_cols.size() || select_cols.empty(), + errors::InvalidArgument("select_cols should match output size")); + for (int i = 1; i < select_cols.size(); i++) { + OP_REQUIRES(ctx, select_cols[i - 1] < select_cols[i],
+ errors::InvalidArgument( + "select_cols should be strictly increasing indices")); + } + OP_REQUIRES( + ctx, select_cols.empty() || select_cols.front() >= 0, + errors::InvalidArgument("select_cols should be non-negative indices")); + + *output = new Dataset(ctx, std::move(filenames), header, buffer_size, + output_types_, output_shapes_, + std::move(record_defaults), std::move(select_cols), + use_quote_delim, delim[0], std::move(na_value)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, std::vector<string> filenames, bool header, + int64 buffer_size, const DataTypeVector& output_types, + const std::vector<PartialTensorShape>& output_shapes, + std::vector<Tensor> record_defaults, std::vector<int64> select_cols, + bool use_quote_delim, char delim, string na_value) + : GraphDatasetBase(ctx), + filenames_(std::move(filenames)), + header_(header), + buffer_size_(buffer_size), + out_type_(output_types), + output_shapes_(output_shapes), + record_defaults_(std::move(record_defaults)), + select_cols_(std::move(select_cols)), + use_quote_delim_(use_quote_delim), + delim_(delim), + na_value_(std::move(na_value)) {} + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::CSV")})); + } + + const DataTypeVector& output_dtypes() const override { return out_type_; } + + const std::vector<PartialTensorShape>& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { return "CSVDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Node** output) const override { + // TODO(rachelim): Implement this + std::vector<Node*> input_tensors; + TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output)); + return errors::Unimplemented("CSVDataset: AsGraphDefInternal"); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params) {} + +
Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + bool select_all = dataset()->select_cols_.empty(); + do { + // We are currently processing a file, so try to read the next record + if (input_stream_) { + Status s = ReadRecord(ctx, out_tensors, select_all, + dataset()->select_cols_); + if (s.ok()) { + // Validate output + if (out_tensors->size() != dataset()->out_type_.size()) { + return errors::InvalidArgument( + "Expect ", dataset()->out_type_.size(), " fields but have ", + out_tensors->size(), " in record"); + } + + *end_of_sequence = false; + return s; + } + if (!errors::IsOutOfRange(s)) { + // Not at the end of file, return OK or non-EOF errors to caller. + *end_of_sequence = false; + return s; + } + // We have reached the end of the current file, so maybe + // move on to next file. + ResetStreamsLocked(); + ++current_file_index_; + } + // Iteration ends when there are no more files to process. + if (current_file_index_ == dataset()->filenames_.size()) { + *end_of_sequence = true; + return Status::OK(); + } + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + } while (true); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + // TODO(rachelim): Implement save + return errors::Unimplemented("CSVDataset: SaveInternal"); + } + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + // TODO(rachelim): Implement restore + return errors::Unimplemented("CSVDataset: RestoreInternal"); + } + + private: + // Reads an entire CSV row from the input stream, either from the + // existing buffer or by filling the buffer as needed. Converts extracted + // fields to output tensors as we go. + // + // When this function is called, pos_ should be the index of the first + // character of the record in buffer_, or past the end of the buffer.
+ // Note: ctx and out_tensors are only used in this function + // when fields are included in the record. + Status ReadRecord(IteratorContext* ctx, std::vector<Tensor>* out_tensors, + bool select_all, const std::vector<int64>& selected) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (pos_ >= buffer_.size()) { + // At the end of the file, this will return errors::OutOfRange + TF_RETURN_IF_ERROR(FillBuffer(&buffer_)); + pos_ = 0; + } + + // The first character may be \n if this is the continuation of a + // \r\n linebreak between this and the previous record. If so, skip it. + + bool end_of_record = false; // Keep track of when we find \n, \r or EOF + size_t num_parsed = 0; + size_t num_selected_parsed = 0; + + Status result; + + while (!end_of_record) { // Read till we reach \n, \r or EOF + bool include = + select_all || (num_selected_parsed < selected.size() && + selected[num_selected_parsed] == num_parsed); + + // Don't fail fast, so that the next call to GetNext may still return + // a valid record + result.Update( + ParseOneField(ctx, out_tensors, &end_of_record, include)); + + num_parsed++; + if (include) num_selected_parsed++; + } + + return result; + } + + // Parses one field from position pos_ in the buffer. Fields are + // delimited by delim, CRLF, or EOF. Advances pos_ to the first char of + // the next field. + Status ParseOneField(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (pos_ >= buffer_.size()) { + // If we get here, this means the previous field's end coincided + // with the end of the buffer. We can fill the buffer without abandon.
+ Status s = FillBuffer(&buffer_); + + if (errors::IsOutOfRange(s)) { + // Reached EOF, and last field is empty + *end_of_record = true; + if (include) { + return FieldToOutput(ctx, StringPiece(), out_tensors); + } else { + return Status::OK(); + } + } else if (!s.ok()) { + return s; // Surface other errors back to caller + } + + pos_ = 0; + } + + if (dataset()->use_quote_delim_ && buffer_[pos_] == '"') { + return ParseQuotedField(ctx, out_tensors, end_of_record, include); + } + + return ParseUnquotedField(ctx, out_tensors, end_of_record, include); + } + + // For keeping track of relevant parts of a field from a previous buffer + struct Piece { + size_t start; + size_t len; + string buffer; + + Piece(string buffer, size_t start, size_t len) + : start(start), len(len), buffer(std::move(buffer)) {} + }; + + // Given that pos_ exceeds the buffer, saves the relevant part of the + // current buffer (if necessary), fills the buffer, and resets indices to + // 0. + Status SaveAndFillBuffer(std::vector<Piece>* earlier_pieces, + size_t* start, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + string temp_buffer; + + buffer_.swap(temp_buffer); + if (include && pos_ > *start) { + earlier_pieces->push_back( + Piece(std::move(temp_buffer), *start, pos_ - *start)); + } + pos_ = 0; + *start = 0; + return FillBuffer(&buffer_); + } + + // Parses quoted field from position pos_ in the buffer. Continually + // reads from buffer until end of field is reached (closing quote followed + // by delim, CRLF, or EOF). Advances pos_ to keep track of our position in + // the buffer as we go, stopping at the first character of the next field. 
+ Status ParseQuotedField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::vector earlier_pieces; + size_t start = pos_; + pos_++; // Starting quotation mark + + Status parse_result; + while (true) { // Each iter reads 1 char, filling buffer if necessary + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + if (errors::IsOutOfRange(s)) { + return errors::InvalidArgument( + "Reached end of file without closing quoted field in " + "record"); + } else if (!s.ok()) { + return s; // Surface all other errors to caller + } + } + + char ch = buffer_[pos_]; + if (ch == '"') { + // When we encounter a quote, we look ahead to the next character to + // decide what to do + pos_++; + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + if (errors::IsOutOfRange(s)) { + // This was the last field. We are done + *end_of_record = true; + parse_result.Update(QuotedFieldToOutput( + ctx, StringPiece(), out_tensors, earlier_pieces, include)); + return parse_result; + } else if (!s.ok()) { + return s; + } + } + + char next = buffer_[pos_]; + pos_++; + if (next == dataset()->delim_) { + parse_result.Update(QuotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - 1 - start), + out_tensors, earlier_pieces, include)); + return parse_result; + + } else if (next == '\n' || next == '\r') { + *end_of_record = true; + parse_result.Update(QuotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - 1 - start), + out_tensors, earlier_pieces, include)); + if (next == '\r') SkipNewLineIfNecessary(); + return parse_result; + } else if (next != '"') { + // Take note of the error, but keep going to end of field. + include = false; // So we don't get funky errors when trying to + // unescape the quotes. 
+ parse_result.Update(errors::InvalidArgument( + "Quote inside a string has to be escaped by another quote")); + } + + } else { + pos_++; + } + } + } + + // Converts quoted field to an output tensor, removing the starting + // and ending quotes from it and unescaping double quotations if + // necessary. + Status QuotedFieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector<Tensor>* out_tensors, + const std::vector<Piece>& earlier_pieces, + bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!include) return Status::OK(); + + if (earlier_pieces.empty()) { + if (field.find('\"', 1) == field.size() - 1) { + // `field` contains no escaped quotation marks. + // Exclude framing quotation marks + field.remove_prefix(1); + field.remove_suffix(1); + return FieldToOutput(ctx, field, out_tensors); + } + } + string field_complete; + size_t str_len = field.size(); + for (const Piece& p : earlier_pieces) { + str_len += p.len; + } + field_complete.reserve(str_len); + + // This bool flips every time we see a quote, so that we skip the second + // quote of every pair of adjacent quotes in the field. We need to track + // this across iterations of the for loop because adjacent double quotes + // may be in different buffers. Initialize to true because we also skip + // the opening quotation mark of the quoted field. 
+ bool skip_next_quote = true; + for (const Piece& p : earlier_pieces) { + AppendUnescapedPiece(StringPiece(&p.buffer[p.start], p.len), + &field_complete, &skip_next_quote); + } + AppendUnescapedPiece(field, &field_complete, &skip_next_quote); + StringPiece result = StringPiece(field_complete); + result.remove_suffix(1); // Skip final quote + + return FieldToOutput(ctx, result, out_tensors); + } + + void AppendUnescapedPiece(StringPiece piece, string* field_complete, + bool* skip_next_quote) { + size_t from = 0; + size_t found = piece.find('\"', from); + while (found != string::npos) { + if (!*skip_next_quote) { + // This is the first quote in a pair of adjacent double quotes + field_complete->append(piece.data() + from, found + 1 - from); + } + *skip_next_quote = !*skip_next_quote; + from = found + 1; + found = piece.find('\"', from); + } + // Include the chunk after the last quotation mark in the string + if (from < piece.size()) { + field_complete->append(piece.data() + from, piece.size() - from); + } + } + + // Parses unquoted field from position pos_ in the buffer. Continually + // reads from buffer until end of field is reached (delim, CRLF, or EOF). + // Advances pos_ to keep track of our position in the buffer as we go, + // stopping at the first character of the next field. 
+ Status ParseUnquotedField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::vector earlier_pieces; + size_t start = pos_; + Status parse_result; + + while (true) { // Each iter reads 1 char, filling buffer if necessary + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + // Handle errors + if (errors::IsOutOfRange(s)) { + // Whatever we have is the last field of the last record + *end_of_record = true; + parse_result.Update(UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include)); + return parse_result; + } else if (!s.ok()) { + return s; // Surface all other errors to caller + } + } + + char ch = buffer_[pos_]; + + if (ch == dataset()->delim_) { + parse_result.Update(UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include)); + pos_++; + return parse_result; + } + if (ch == '\n' || ch == '\r') { + // need special case to skip over first \n of record if the line + // breaks are \r\n + parse_result.Update(UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include)); + *end_of_record = true; + pos_++; + if (ch == '\r') SkipNewLineIfNecessary(); + return parse_result; + } + if (dataset()->use_quote_delim_ && ch == '"') { + // Take note of the error, but keep going to end of field. + parse_result.Update(errors::InvalidArgument( + "Unquoted fields cannot have quotes inside")); + } + // Otherwise, go to next character + pos_++; + } + } + + Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + result->clear(); + Status s = input_stream_->ReadNBytes(dataset()->buffer_size_, result); + + if (errors::IsOutOfRange(s) && !result->empty()) { + // Ignore OutOfRange error when ReadNBytes read < N bytes. 
+ return Status::OK(); + } + return s; + } + + // Given a field, converts it to the right output tensor type + Status FieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector* out_tensors) { + size_t output_idx = out_tensors->size(); + if (output_idx >= dataset()->out_type_.size()) { + // We can get here if we're selecting all columns, but the number of + // fields exceeds the number of defaults provided + return errors::InvalidArgument("Expect ", dataset()->out_type_.size(), + " fields but have more in record"); + } + const DataType& dtype = dataset()->out_type_[output_idx]; + Tensor component(ctx->allocator({}), dtype, {}); + if ((field.empty() || field == dataset()->na_value_) && + dataset()->record_defaults_[output_idx].NumElements() != 1) { + // If the field is empty or NA value, and default is not given, + // report error. + return errors::InvalidArgument("Field ", output_idx, + " is required but missing in record!"); + } + + switch (dtype) { + // For each case, if the field is empty, we use the default. + // Otherwise, we convert it to the right type. 
+ case DT_INT32: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar<int32>()() = + dataset()->record_defaults_[output_idx].flat<int32>()(0); + } else { + int32 value; + if (!strings::safe_strto32(field, &value)) { + return errors::InvalidArgument( + "Field ", output_idx, + " in record is not a valid int32: ", field); + } + component.scalar<int32>()() = value; + } + break; + } + case DT_INT64: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar<int64>()() = + dataset()->record_defaults_[output_idx].flat<int64>()(0); + } else { + int64 value; + if (!strings::safe_strto64(field, &value)) { + return errors::InvalidArgument( + "Field ", output_idx, + " in record is not a valid int64: ", field); + } + component.scalar<int64>()() = value; + } + break; + } + case DT_FLOAT: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar<float>()() = + dataset()->record_defaults_[output_idx].flat<float>()(0); + } else { + float value; + if (!strings::safe_strtof(field, &value)) { + return errors::InvalidArgument( + "Field ", output_idx, + " in record is not a valid float: ", field); + } + component.scalar<float>()() = value; + } + break; + } + case DT_DOUBLE: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar<double>()() = + dataset()->record_defaults_[output_idx].flat<double>()(0); + } else { + double value; + if (!strings::safe_strtod(field, &value)) { + return errors::InvalidArgument( + "Field ", output_idx, + " in record is not a valid double: ", field); + } + component.scalar<double>()() = value; + } + break; + } + case DT_STRING: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar<string>()() = + dataset()->record_defaults_[output_idx].flat<string>()(0); + } else { + component.scalar<string>()() = field.ToString(); + } + break; + } + default: + return errors::InvalidArgument("csv: data type ", dtype, + " not supported in field ", + output_idx); + } + out_tensors->push_back(std::move(component)); + return Status::OK(); + } + + // Records can be delimited by "\r\n" 
line breaks. When we encounter a + // '\r', we have to check the next character to see if it is part of the + // linebreak, and ignore it if so. + void SkipNewLineIfNecessary() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (pos_ >= buffer_.size()) { + Status s = FillBuffer(&buffer_); + pos_ = 0; + // If we failed to fill buffer, it doesn't matter because we're done + // with the record + if (!s.ok()) return; + } + if (buffer_[pos_] == '\n') { + pos_++; + } + } + + // Given a string field, and its index in the output, + // converts it to a Tensor of the right type and adds it to the + // out_tensors vector. + Status UnquotedFieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector<Tensor>* out_tensors, + const std::vector<Piece>& earlier_pieces, + bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!include) return Status::OK(); + + if (earlier_pieces.empty()) { + return FieldToOutput(ctx, field, out_tensors); + } + + size_t str_len = field.size(); + for (const Piece& p : earlier_pieces) { + str_len += p.len; + } + string field_complete; + field_complete.reserve(str_len); + + for (const Piece& p : earlier_pieces) { + field_complete.append(p.buffer, p.start, p.len); + } + + field_complete.append(field.data(), field.size()); + return FieldToOutput(ctx, field_complete, out_tensors); + } + + // Sets up reader streams to read from the file at `current_file_index_`. + Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (current_file_index_ >= dataset()->filenames_.size()) { + return errors::InvalidArgument( + "current_file_index_:", current_file_index_, + " >= filenames_.size():", dataset()->filenames_.size()); + } + + // Actually move on to next file. + TF_RETURN_IF_ERROR(env->NewRandomAccessFile( + dataset()->filenames_[current_file_index_], &file_)); + input_stream_.reset( + new io::RandomAccessInputStream(file_.get(), false)); + buffer_.clear(); + pos_ = 0; + if (dataset()->header_) { + // Read one line, but don't include it. 
Pass nullptrs as dummy + // pointers to objects that shouldn't be invoked anyway + // We need to process this as a record here instead of just finding + // the first newline because it might contain quoted fields with + // newlines in the header as well + std::vector empty; + Status s = ReadRecord(nullptr, nullptr, false, empty); + if (!s.ok()) { + return errors::InvalidArgument("Can't read header of file"); + } + } + return Status::OK(); + } + + // Resets all reader streams. + void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + input_stream_.reset(); + file_.reset(); + } + + mutex mu_; + string buffer_ GUARDED_BY(mu_); // Maintain our own buffer + size_t pos_ GUARDED_BY( + mu_); // Index into the buffer must be maintained between iters + std::unique_ptr input_stream_ + GUARDED_BY(mu_); + size_t current_file_index_ GUARDED_BY(mu_) = 0; + std::unique_ptr file_ + GUARDED_BY(mu_); // must outlive input_stream_ + }; // class Iterator + + const std::vector filenames_; + const bool header_; + const int64 buffer_size_; + const DataTypeVector out_type_; + const std::vector output_shapes_; + const std::vector record_defaults_; + const std::vector select_cols_; + const bool use_quote_delim_; + const char delim_; + const string na_value_; + }; // class Dataset + + DataTypeVector output_types_; + std::vector output_shapes_; +}; // class CSVDatasetOp + +// Register the kernel implementation for CSVDataset. 
+REGISTER_KERNEL_BUILDER(Name("CSVDataset").Device(DEVICE_CPU), CSVDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc index 48d3734162525ffc6ace076e4f0523c1d0cae511..6a12ca06f4d6cc2096aaf8191a01a899881b43db 100644 --- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc @@ -91,7 +91,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { } } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr(new Iterator( {this, strings::StrCat(prefix, "::DirectedInterleave")})); @@ -105,7 +105,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { return output_shapes_; } - string DebugString() override { + string DebugString() const override { return strings::StrCat("DirectedInterleaveDatasetOp::Dataset"); } @@ -130,15 +130,21 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { public: explicit Iterator(const Params& params) : DatasetIterator(params), - selector_input_impl_(params.dataset->selector_input_->MakeIterator( - params.prefix + ".selector")), - num_active_inputs_(params.dataset->data_inputs_.size()) { - data_input_impls_.reserve(params.dataset->data_inputs_.size()); - for (size_t i = 0; i < params.dataset->data_inputs_.size(); ++i) { - const DatasetBase* data_input = params.dataset->data_inputs_[i]; - data_input_impls_.push_back(data_input->MakeIterator( - strings::StrCat(params.prefix, "[", i, "]"))); + num_active_inputs_(params.dataset->data_inputs_.size()) {} + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator( + ctx, strings::StrCat(prefix(), ".selector"), + &selector_input_impl_)); + 
data_input_impls_.resize(dataset()->data_inputs_.size()); + for (size_t i = 0; i < data_input_impls_.size(); ++i) { + const DatasetBase* data_input = dataset()->data_inputs_[i]; + TF_RETURN_IF_ERROR(data_input->MakeIterator( + ctx, strings::StrCat(prefix(), "[", i, "]"), + &data_input_impls_[i])); } + return Status::OK(); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc index bb29df60e8f114aaa50f578c43e73874f72ab0a3..bbec50681c6f5decec5a3b5fbf09cc3011a21199 100644 --- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc @@ -44,7 +44,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::IgnoreErrors")})); @@ -57,7 +57,9 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { return "IgnoreErrorsDatasetOp::Dataset"; } + string DebugString() const override { + return "IgnoreErrorsDatasetOp::Dataset"; + } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, @@ -72,8 +74,11 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc 
b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc index 63e19ae3f837c9d3cfb1221df64360ee74117f13..3dfc3741c2b040dd5be3223c24f0715ba3be4248 100644 --- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc @@ -127,7 +127,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { threadpool_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::ThreadPool")})); @@ -140,7 +140,9 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { return "ThreadPoolDatasetOp::Dataset"; } + string DebugString() const override { + return "ThreadPoolDatasetOp::Dataset"; + } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, @@ -154,8 +156,11 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc index 69fbb0fcdcce87951d2c9b84210fda378081b103..67c237799c10a2724f18bb0df99e4bf8f5cd2b8a 100644 --- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/unique_dataset_op.cc @@ -56,7 +56,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { 
return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Unique")})); @@ -70,7 +70,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { + string DebugString() const override { return strings::StrCat("UniqueDatasetOp::Dataset"); } @@ -87,8 +87,11 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const typename Iterator::Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc index 137deb63527f0bdde7da8d5be83ed038f430e581..f271d269ab1b9339de4657e459dcbbd462890f0a 100644 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -34,6 +34,40 @@ data_input_datasets: `N` datasets with the same type that will be interleaved according to the values of `selector_input_dataset`. )doc"); +REGISTER_OP("CSVDataset") + .Input("filenames: string") + .Input("buffer_size: int64") + .Input("header: bool") + .Input("field_delim: string") + .Input("use_quote_delim: bool") + .Input("na_value: string") + .Input("select_cols: int64") + .Input("record_defaults: output_types") + .Output("handle: variant") + .Attr("output_types: list({float,double,int32,int64,string}) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // `filenames` must be a scalar or a vector. 
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); + // `buffer_size`, `header`, `field_delim`, `use_quote_delim`, + // `na_value` must be scalars + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + // `select_cols` must be a vector + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &unused)); + // `record_defaults` must be a list of scalars...? + for (size_t i = 7; i < c->num_inputs(); ++i) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &unused)); + } + return shape_inference::ScalarShape(c); + }); + REGISTER_OP("IgnoreErrorsDataset") .Input("input_dataset: variant") .Output("handle: variant") diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 6017e27e731e3e8bcdee516ea291b17cd0782e63..be834d7dfdb6e143628642261508b56f5ee78395 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -11,7 +11,10 @@ py_test( size = "medium", srcs = ["batch_dataset_op_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + tags = [ + "no_oss", # (b/79552534) + "no_pip", + ], deps = [ ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:batching", @@ -117,6 +120,20 @@ py_library( ], ) +py_test( + name = "csv_dataset_op_test", + size = "small", + srcs = ["csv_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:error_ops", + "//tensorflow/contrib/data/python/ops:readers", + "//third_party/py/numpy", + ], +) + py_test( name = "filter_dataset_op_test", size = "small", @@ -192,6 +209,23 @@ py_test( ], ) +py_test( + name = "directed_interleave_dataset_test", + size = "medium", 
+ srcs = ["directed_interleave_dataset_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + tf_py_test( name = "get_single_element_test", size = "small", @@ -246,6 +280,19 @@ py_test( ], ) +py_test( + name = "optimize_dataset_op_test", + size = "small", + srcs = ["optimize_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/python:platform", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_test( name = "prefetch_dataset_op_test", size = "small", @@ -283,16 +330,37 @@ py_test( ], ) +py_library( + name = "reader_dataset_ops_test_base", + testonly = 1, + srcs = [ + "reader_dataset_ops_test_base.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:lib", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:readers", + ], +) + py_test( name = "reader_dataset_ops_test", size = "medium", srcs = ["reader_dataset_ops_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ ":dataset_serialization_test", + ":reader_dataset_ops_test_base", "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -301,8 +369,10 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:lib", "//tensorflow/python:parsing_ops", + 
"//tensorflow/python:string_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:readers", "//third_party/py/numpy", ], ) @@ -392,6 +462,7 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -399,6 +470,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:training", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", @@ -411,6 +483,7 @@ py_test( srcs = ["sql_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -428,10 +501,15 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", + ":reader_dataset_ops_test_base", "//tensorflow/contrib/data/python/ops:stats_ops", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 2568b899d7ea1be685036ad8af93f584f861c951..b5fbc45ad3d8d262c1c79b5723ffeb38ff6a34c2 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -552,6 +552,44 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testMapAndBatchParallelGetNext(self): + iterator = 
(dataset_ops.Dataset.range(50000) + .apply(batching.map_and_batch(lambda x: x, batch_size=100)) + .make_one_shot_iterator()) + elements = [] + for _ in range(100): + elements.append(iterator.get_next()) + with self.test_session() as sess: + for i in range(5): + got = sess.run(elements) + got.sort(key=lambda x: x[0]) + expected = [] + for j in range(100): + expected.append(range(i*10000+j*100, i*10000+(j+1)*100)) + self.assertAllEqual(got, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elements) + + def testMapAndBatchParallelGetNextDropRemainder(self): + iterator = ( + dataset_ops.Dataset.range(49999).apply( + batching.map_and_batch( + lambda x: x, batch_size=100, drop_remainder=True)) + .make_one_shot_iterator()) + elements = [] + for _ in range(100): + elements.append(iterator.get_next()) + with self.test_session() as sess: + for i in range(4): + got = sess.run(elements) + got.sort(key=lambda x: x[0]) + expected = [] + for j in range(100): + expected.append(range(i*10000+j*100, i*10000+(j+1)*100)) + self.assertAllEqual(got, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elements) + def testMapAndBatchSparse(self): def _sparse(i): diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..97b5e9416521dcad9ee5047a8275f8fd0142e338 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -0,0 +1,619 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for CsvDatasetOp.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import string +import tempfile +import time + +import numpy as np + +from tensorflow.contrib.data.python.ops import error_ops +from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.client import session +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_parsing_ops +from tensorflow.python.platform import gfile +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test + + +class CsvDatasetOpTest(test.TestCase): + + def _assert_datasets_equal(self, g, ds1, ds2): + assert ds1.output_shapes == ds2.output_shapes, ('output_shapes differ: %s, ' + '%s') % (ds1.output_shapes, + ds2.output_shapes) + assert ds1.output_types == ds2.output_types + assert ds1.output_classes == ds2.output_classes + next1 = ds1.make_one_shot_iterator().get_next() + next2 = ds2.make_one_shot_iterator().get_next() + with self.test_session(graph=g) as sess: + # Run through datasets and check that outputs match, or errors match. 
+ while True: + try: + op1 = sess.run(next1) + except (errors.OutOfRangeError, ValueError) as e: + # If op1 throws an exception, check that op2 throws same exception. + with self.assertRaises(type(e)): + sess.run(next2) + break + op2 = sess.run(next2) + self.assertAllEqual(op1, op2) + + def setup_files(self, inputs, linebreak='\n'): + filenames = [] + for i, ip in enumerate(inputs): + fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i) + with open(fn, 'wb') as f: + f.write(linebreak.join(ip).encode('utf-8')) + filenames.append(fn) + return filenames + + def _make_test_datasets(self, inputs, **kwargs): + # Test by comparing its output to what we could get with map->decode_csv + filenames = self.setup_files(inputs) + dataset_expected = core_readers.TextLineDataset(filenames) + dataset_expected = dataset_expected.map( + lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) + dataset_actual = readers.CsvDataset(filenames, **kwargs) + return (dataset_actual, dataset_expected) + + def _test_by_comparison(self, inputs, **kwargs): + """Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv).""" + with ops.Graph().as_default() as g: + dataset_actual, dataset_expected = self._make_test_datasets( + inputs, **kwargs) + self._assert_datasets_equal(g, dataset_actual, dataset_expected) + + def _verify_output_or_err(self, + sess, + dataset, + expected_output=None, + expected_err_re=None): + nxt = dataset.make_one_shot_iterator().get_next() + if expected_err_re is None: + # Verify that output is expected, without errors + expected_output = [[ + v.encode('utf-8') if isinstance(v, str) else v for v in op + ] for op in expected_output] + for value in expected_output: + op = sess.run(nxt) + self.assertAllEqual(op, value) + with self.assertRaises(errors.OutOfRangeError): + sess.run(nxt) + else: + # Verify that OpError is produced as expected + with self.assertRaisesOpError(expected_err_re): + while True: + try: + sess.run(nxt) + except errors.OutOfRangeError: + break + + 
def _test_dataset(self, + inputs, + expected_output=None, + expected_err_re=None, + linebreak='\n', + **kwargs): + """Checks that elements produced by CsvDataset match expected output.""" + # Convert str type because py3 tf strings are bytestrings + filenames = self.setup_files(inputs, linebreak) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.CsvDataset(filenames, **kwargs) + self._verify_output_or_err(sess, dataset, expected_output, + expected_err_re) + + def testCsvDataset_requiredFields(self): + record_defaults = [[]] * 4 + inputs = [['1,2,3,4']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_int(self): + record_defaults = [[0]] * 4 + inputs = [['1,2,3,4', '5,6,7,8']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_float(self): + record_defaults = [[0.0]] * 4 + inputs = [['1.0,2.1,3.2,4.3', '5.4,6.5,7.6,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_string(self): + record_defaults = [['']] * 4 + inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withEmptyFields(self): + record_defaults = [[0]] * 4 + inputs = [[',,,', '1,1,1,', ',2,2,2']] + self._test_dataset( + inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], + record_defaults=record_defaults) + + def testCsvDataset_errWithUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4']] + self._test_dataset( + inputs, + expected_err_re='Unquoted fields cannot have quotes inside', + record_defaults=record_defaults) + + def testCsvDataset_errWithUnescapedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['"a"b","c","d"']] + self._test_dataset( + inputs, + expected_err_re= + 'Quote inside a string has to be escaped by another quote', + record_defaults=record_defaults) + + def 
testCsvDataset_ignoreErrWithUnescapedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']] + filenames = self.setup_files(inputs) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) + dataset = dataset.apply(error_ops.ignore_errors()) + self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']]) + + def testCsvDataset_ignoreErrWithUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']] + filenames = self.setup_files(inputs) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) + dataset = dataset.apply(error_ops.ignore_errors()) + self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']]) + + def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, use_quote_delim=False) + + def testCsvDataset_mixedTypes(self): + record_defaults = [ + constant_op.constant([], dtype=dtypes.int32), + constant_op.constant([], dtype=dtypes.float32), + constant_op.constant([], dtype=dtypes.string), + constant_op.constant([], dtype=dtypes.float64) + ] + inputs = [['1,2.1,3.2,4.3', '5,6.5,7.6,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withUseQuoteDelimFalse(self): + record_defaults = [['']] * 4 + inputs = [['1,2,"3,4"', '"5,6",7,8']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, use_quote_delim=False) + + def testCsvDataset_withFieldDelim(self): + record_defaults = [[0]] * 4 + inputs = [['1:2:3:4', '5:6:7:8']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, field_delim=':') + + def testCsvDataset_withNaValue(self): + record_defaults = [[0]] * 4 
+ inputs = [['1,NA,3,4', 'NA,6,7,8']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, na_value='NA') + + def testCsvDataset_withSelectCols(self): + record_defaults = [['']] * 2 + inputs = [['1,2,3,4', '"5","6","7","8"']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, select_cols=[1, 2]) + + def testCsvDataset_withSelectColsTooHigh(self): + record_defaults = [[0]] * 2 + inputs = [['1,2,3,4', '5,6,7,8']] + self._test_dataset( + inputs, + expected_err_re='Expect 2 fields but have 1 in record', + record_defaults=record_defaults, + select_cols=[3, 4]) + + def testCsvDataset_withOneCol(self): + record_defaults = [['NA']] + inputs = [['0', '', '2']] + self._test_dataset( + inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults) + + def testCsvDataset_withMultipleFiles(self): + record_defaults = [[0]] * 4 + inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withLeadingAndTrailingSpaces(self): + record_defaults = [[0.0]] * 4 + inputs = [['0, 1, 2, 3']] + expected = [[0.0, 1.0, 2.0, 3.0]] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_errorWithMissingDefault(self): + record_defaults = [[]] * 2 + inputs = [['0,']] + self._test_dataset( + inputs, + expected_err_re='Field 1 is required but missing in record!', + record_defaults=record_defaults) + + def testCsvDataset_errorWithFewerDefaultsThanFields(self): + record_defaults = [[0.0]] * 2 + inputs = [['0,1,2,3']] + self._test_dataset( + inputs, + expected_err_re='Expect 2 fields but have more in record', + record_defaults=record_defaults) + + def testCsvDataset_errorWithMoreDefaultsThanFields(self): + record_defaults = [[0.0]] * 5 + inputs = [['0,1,2,3']] + self._test_dataset( + inputs, + expected_err_re='Expect 5 fields but have 4 in record', + record_defaults=record_defaults) + + def testCsvDataset_withHeader(self): + record_defaults = 
[[0]] * 2 + inputs = [['col1,col2', '1,2']] + expected = [[1, 2]] + self._test_dataset( + inputs, + expected, + record_defaults=record_defaults, + header=True, + ) + + def testCsvDataset_withHeaderAndNoRecords(self): + record_defaults = [[0]] * 2 + inputs = [['col1,col2']] + expected = [] + self._test_dataset( + inputs, + expected, + record_defaults=record_defaults, + header=True, + ) + + def testCsvDataset_errorWithHeaderEmptyFile(self): + record_defaults = [[0]] * 2 + inputs = [[]] + expected_err_re = "Can't read header of file" + self._test_dataset( + inputs, + expected_err_re=expected_err_re, + record_defaults=record_defaults, + header=True, + ) + + def testCsvDataset_withEmptyFile(self): + record_defaults = [['']] * 2 + inputs = [['']] # Empty file + self._test_dataset( + inputs, expected_output=[], record_defaults=record_defaults) + + def testCsvDataset_errorWithEmptyRecord(self): + record_defaults = [['']] * 2 + inputs = [['', '1,2']] # First record is empty + self._test_dataset( + inputs, + expected_err_re='Expect 2 fields but have 1 in record', + record_defaults=record_defaults) + + def testCsvDataset_withChainedOps(self): + # Testing that one dataset can create multiple iterators fine. + # `repeat` creates multiple iterators from the same C++ Dataset. 
+ record_defaults = [[0]] * 4 + inputs = [['1,,3,4', '5,6,,8']] + ds_actual, ds_expected = self._make_test_datasets( + inputs, record_defaults=record_defaults) + with ops.Graph().as_default() as g: + self._assert_datasets_equal(g, + ds_actual.repeat(5).prefetch(1), + ds_expected.repeat(5).prefetch(1)) + + def testCsvDataset_withTypeDefaults(self): + # Testing using dtypes as record_defaults for required fields + record_defaults = [dtypes.float32, [0.0]] + inputs = [['1.0,2.0', '3.0,4.0']] + self._test_dataset( + inputs, + [[1.0, 2.0], [3.0, 4.0]], + record_defaults=record_defaults, + ) + + def testMakeCsvDataset_fieldOrder(self): + data = [[ + '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19', + '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19' + ]] + file_path = self.setup_files(data) + + with ops.Graph().as_default() as g: + ds = readers.make_csv_dataset( + file_path, batch_size=1, shuffle=False, num_epochs=1) + next_batch = ds.make_one_shot_iterator().get_next() + + with self.test_session(graph=g) as sess: + result = list(sess.run(next_batch).values()) + + self.assertEqual(result, sorted(result)) + +## The following tests exercise parsing logic for quoted fields + + def testCsvDataset_withQuoted(self): + record_defaults = [['']] * 4 + inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withOneColAndQuotes(self): + record_defaults = [['']] + inputs = [['"0"', '"1"', '"2"']] + self._test_dataset( + inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults) + + def testCsvDataset_withNewLine(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def 
testCsvDataset_withNewLineInUnselectedCol(self): + record_defaults = [['']] + inputs = [['1,"2\n3",4', '5,6,7']] + self._test_dataset( + inputs, + expected_output=[['1'], ['5']], + record_defaults=record_defaults, + select_cols=[0]) + + def testCsvDataset_withMultipleNewLines(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_errorWithTerminateMidRecord(self): + record_defaults = [['']] * 4 + inputs = [['a,b,c,"a']] + self._test_dataset( + inputs, + expected_err_re= + 'Reached end of file without closing quoted field in record', + record_defaults=record_defaults) + + def testCsvDataset_withEscapedQuotes(self): + record_defaults = [['']] * 4 + inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + +## Testing that parsing works with all buffer sizes, quoted/unquoted fields, +## and different types of line breaks + + def testCsvDataset_withInvalidBufferSize(self): + record_defaults = [['']] * 4 + inputs = [['a,b,c,d']] + self._test_dataset( + inputs, + expected_err_re='buffer_size should be positive', + record_defaults=record_defaults, + buffer_size=0) + + def testCsvDataset_withBufferSize(self): + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, expected, record_defaults=record_defaults, buffer_size=i + 1) + + def testCsvDataset_withCR(self): + # Test that when the line separator is '\r', parsing works with all buffer + # sizes + record_defaults 
= [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r', + record_defaults=record_defaults, + buffer_size=i + 1) + + def testCsvDataset_withCRLF(self): + # Test that when the line separator is '\r\n', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r\n', + record_defaults=record_defaults, + buffer_size=i + 1) + + def testCsvDataset_withBufferSizeAndQuoted(self): + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\n', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\n', record_defaults=record_defaults) + + def testCsvDataset_withCRAndQuoted(self): + # Test that when the line separator is '\r', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\r', record_defaults=record_defaults) + + def testCsvDataset_withCRLFAndQuoted(self): + # Test that when the 
line separator is '\r\n', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r\n', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\r\n', record_defaults=record_defaults) + + +class CsvDatasetBenchmark(test.Benchmark): + """Benchmarks for the various ways of creating a dataset from CSV files. + """ + FLOAT_VAL = '1.23456E12' + STR_VAL = string.ascii_letters * 10 + + def _setUp(self, str_val): + # Since this isn't test.TestCase, have to manually create a test dir + gfile.MakeDirs(googletest.GetTempDir()) + self._temp_dir = tempfile.mkdtemp(dir=googletest.GetTempDir()) + + self._num_cols = [4, 64, 256] + self._num_per_iter = 5000 + self._filenames = [] + for n in self._num_cols: + fn = os.path.join(self._temp_dir, 'file%d.csv' % n) + with open(fn, 'wb') as f: + # Just write 100 rows and use `repeat`... Assumes the cost + # of creating an iterator is not significant + row = ','.join([str_val for _ in range(n)]) + f.write('\n'.join([row for _ in range(100)])) + self._filenames.append(fn) + + def _tearDown(self): + gfile.DeleteRecursively(self._temp_dir) + + def _runBenchmark(self, dataset, num_cols, prefix): + dataset = dataset.skip(self._num_per_iter - 1) + deltas = [] + for _ in range(10): + next_element = dataset.make_one_shot_iterator().get_next() + with session.Session() as sess: + start = time.time() + # NOTE: This depends on the underlying implementation of skip, to have + # the net effect of calling `GetNext` num_per_iter times on the + # input dataset. 
We do it this way (instead of a python for loop, or + # batching N inputs in one iter) so that the overhead from session.run + # or batch doesn't dominate. If we eventually optimize skip, this has + # to change. + sess.run(next_element) + end = time.time() + deltas.append(end - start) + # Median wall time per CSV record read and decoded + median_wall_time = np.median(deltas) / self._num_per_iter + print('%s num_cols: %d Median wall time: %f' % (prefix, num_cols, + median_wall_time)) + self.report_benchmark( + iters=self._num_per_iter, + wall_time=median_wall_time, + name='%s_with_cols_%d' % (prefix, num_cols)) + + def benchmarkMapWithFloats(self): + self._setUp(self.FLOAT_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [[0.0]] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_float_map_decode_csv') + self._tearDown() + + def benchmarkMapWithStrings(self): + self._setUp(self.STR_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [['']] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_strings_map_decode_csv') + self._tearDown() + + def benchmarkCsvDatasetWithFloats(self): + self._setUp(self.FLOAT_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [[0.0]] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_float_fused_dataset') + self._tearDown() + + 
def benchmarkCsvDatasetWithStrings(self): + self._setUp(self.STR_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [['']] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_strings_fused_dataset') + self._tearDown() + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py index 78ecce8f7daaf84002ae78d8d77820755b967d89..393f08850b1865180a8b94e9209b2445b54c8b69 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py @@ -467,7 +467,8 @@ class DatasetSerializationTestBase(test.TestCase): ckpt_saved=False, init_before_restore=False, sparse_tensors=False, - verify_exhausted=True): + verify_exhausted=True, + save_checkpoint_at_end=True): """Generates elements from input dataset while stopping at break points. Produces `num_outputs` outputs and saves the state of the iterator in the @@ -490,6 +491,10 @@ class DatasetSerializationTestBase(test.TestCase): sparse_tensors: Whether dataset is built from SparseTensor(s). verify_exhausted: Whether to verify that the iterator has been exhausted after producing `num_outputs` elements. + save_checkpoint_at_end: Whether to save a checkpoint after producing all + outputs. If False, checkpoints are saved each break point but not at the + end. Note that checkpoints overwrite each other so there is always only + a single checkpoint available. Defaults to True. Returns: A list of `num_outputs` items. 
@@ -526,8 +531,9 @@ class DatasetSerializationTestBase(test.TestCase): if i == len(break_points) and verify_exhausted: with self.assertRaises(errors.OutOfRangeError): sess.run(get_next_op) - self._save(sess, saver) - ckpt_saved = True + if save_checkpoint_at_end or i < len(break_points): + self._save(sess, saver) + ckpt_saved = True return outputs diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py new file mode 100644 index 0000000000000000000000000000000000000000..34b6a080c0aae7dfc228746139acc52cea4e6f28 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py @@ -0,0 +1,167 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import random_seed +from tensorflow.python.platform import test + + +class DirectedInterleaveDatasetTest(test.TestCase): + + def testBasic(self): + selector_dataset = dataset_ops.Dataset.range(10).repeat(100) + input_datasets = [ + dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10) + ] + dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset, + input_datasets) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for _ in range(100): + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def _normalize(self, vec): + return vec / vec.sum() + + def _chi2(self, expected, actual): + actual = np.asarray(actual) + expected = np.asarray(expected) + diff = actual - expected + chi2 = np.sum(diff * diff / expected, axis=0) + return chi2 + + def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples): + # Create a dataset that samples each integer in `[0, num_datasets)` + # with probability given by `weights[i]`. 
+ dataset = interleave_ops.sample_from_datasets([ + dataset_ops.Dataset.from_tensors(i).repeat(None) + for i in range(num_datasets) + ], weights) + dataset = dataset.take(num_samples) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + freqs = np.zeros([num_datasets]) + for _ in range(num_samples): + freqs[sess.run(next_element)] += 1 + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + return freqs + + def testSampleFromDatasets(self): + random_seed.set_random_seed(1619) + num_samples = 5000 + rand_probs = self._normalize(np.random.random_sample((15,))) + + # Use chi-squared test to assert that the observed distribution matches the + # expected distribution. Based on the implementation in + # "tensorflow/python/kernel_tests/multinomial_op_test.py". + for probs in [[.85, .05, .1], rand_probs]: + probs = np.asarray(probs) + classes = len(probs) + freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples) + self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2) + + # Also check that `weights` as a dataset samples correctly. 
+ probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat() + freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples) + self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2) + + def testSelectFromDatasets(self): + words = [b"foo", b"bar", b"baz"] + datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words] + choice_array = np.random.randint(3, size=(15,), dtype=np.int64) + choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array) + dataset = interleave_ops.choose_from_datasets(datasets, choice_dataset) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for i in choice_array: + self.assertEqual(words[i], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testErrors(self): + with self.assertRaisesRegexp(ValueError, + r"vector of length `len\(datasets\)`"): + interleave_ops.sample_from_datasets( + [dataset_ops.Dataset.range(10), + dataset_ops.Dataset.range(20)], + weights=[0.25, 0.25, 0.25, 0.25]) + + with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"): + interleave_ops.sample_from_datasets( + [dataset_ops.Dataset.range(10), + dataset_ops.Dataset.range(20)], + weights=[1, 1]) + + with self.assertRaisesRegexp(TypeError, "must have the same type"): + interleave_ops.sample_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(0.0) + ]) + + with self.assertRaisesRegexp(TypeError, "tf.int64"): + interleave_ops.choose_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(1) + ], choice_dataset=dataset_ops.Dataset.from_tensors(1.0)) + + with self.assertRaisesRegexp(TypeError, "scalar"): + interleave_ops.choose_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(1) + ], choice_dataset=dataset_ops.Dataset.from_tensors([1.0])) + + +class 
SampleFromDatasetsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, probs, num_samples): + dataset = interleave_ops.sample_from_datasets( + [ + dataset_ops.Dataset.from_tensors(i).repeat(None) + for i in range(len(probs)) + ], + probs, + seed=1813) + return dataset.take(num_samples) + + def testSerializationCore(self): + self.run_core_tests( + lambda: self._build_dataset([0.5, 0.5], 100), + lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index 43aa4b1bd02791ff304a990c0bbe8e45534c0c77..bee561e3e23a2ab6f314894caa21785347e6ca8b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -30,7 +30,6 @@ from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -907,114 +906,5 @@ class ParallelInterleaveDatasetTest(test.TestCase): sess.run(self.next_element) -class DirectedInterleaveDatasetTest(test.TestCase): - - def testBasic(self): - selector_dataset = dataset_ops.Dataset.range(10).repeat(100) - input_datasets = [ - dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10) - ] - dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset, - input_datasets) - iterator = dataset.make_initializable_iterator() - next_element = iterator.get_next() - - with self.test_session() as sess: - sess.run(iterator.initializer) - for _ 
in range(100): - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def _normalize(self, vec): - return vec / vec.sum() - - def _chi2(self, expected, actual): - actual = np.asarray(actual) - expected = np.asarray(expected) - diff = actual - expected - chi2 = np.sum(diff * diff / expected, axis=0) - return chi2 - - def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples): - # Create a dataset that samples each integer in `[0, num_datasets)` - # with probability given by `weights[i]`. - dataset = interleave_ops.sample_from_datasets([ - dataset_ops.Dataset.from_tensors(i).repeat(None) - for i in range(num_datasets) - ], weights) - dataset = dataset.take(num_samples) - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - with self.test_session() as sess: - freqs = np.zeros([num_datasets]) - for _ in range(num_samples): - freqs[sess.run(next_element)] += 1 - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - return freqs - - def testSampleFromDatasets(self): - random_seed.set_random_seed(1619) - num_samples = 10000 - rand_probs = self._normalize(np.random.random_sample((15,))) - - # Use chi-squared test to assert that the observed distribution matches the - # expected distribution. Based on the implementation in - # "tensorflow/python/kernel_tests/multinomial_op_test.py". - for probs in [[.85, .05, .1], rand_probs]: - probs = np.asarray(probs) - classes = len(probs) - freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples) - self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3) - - # Also check that `weights` as a dataset samples correctly. 
- probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat() - freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples) - self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3) - - def testErrors(self): - with self.assertRaisesRegexp(ValueError, - r"vector of length `len\(datasets\)`"): - interleave_ops.sample_from_datasets( - [dataset_ops.Dataset.range(10), - dataset_ops.Dataset.range(20)], - weights=[0.25, 0.25, 0.25, 0.25]) - - with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"): - interleave_ops.sample_from_datasets( - [dataset_ops.Dataset.range(10), - dataset_ops.Dataset.range(20)], - weights=[1, 1]) - - with self.assertRaisesRegexp(TypeError, "must have the same type"): - interleave_ops.sample_from_datasets([ - dataset_ops.Dataset.from_tensors(0), - dataset_ops.Dataset.from_tensors(0.0) - ]) - - -class SampleFromDatasetsSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset(self, probs, num_samples): - dataset = interleave_ops.sample_from_datasets( - [ - dataset_ops.Dataset.from_tensors(i).repeat(None) - for i in range(len(probs)) - ], - probs, - seed=1813) - return dataset.take(num_samples) - - def testSerializationCore(self): - self.run_core_tests( - lambda: self._build_dataset([0.5, 0.5], 100), - lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..30f1847dcddbfaf379ef2b09185f7a8db4aaeae2 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py @@ -0,0 +1,89 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class OptimizeDatasetTest(test.TestCase): + + def testDefaultOptimizations(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize()) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + all([node.op != "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testEmptyOptimizations(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize([])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() 
+ + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + all([node.op != "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testOptimization(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize(["map_and_batch_fusion"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + any([node.op == "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +class OptimizeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testCore(self): + + def build_dataset(num_elements, batch_size): + return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch( + batch_size).apply(optimization.optimize(["map_and_batch_fusion"])) + + self.run_core_tests(lambda: build_dataset(200, 10), None, 20) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 1075302bae96ca2e0111efbacdf5e919ea76897d..3b07ef290bc38daa37472ef8919f3350851fe370 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -24,9 +24,8 @@ import zlib import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests import 
reader_dataset_ops_test_base from tensorflow.contrib.data.python.ops import readers -from tensorflow.core.example import example_pb2 -from tensorflow.core.example import feature_pb2 from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.framework import constant_op @@ -36,6 +35,7 @@ from tensorflow.python.framework import ops from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import string_ops from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -256,185 +256,31 @@ class TFRecordDatasetSerializationTest( lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) -class ReadBatchFeaturesTest(test.TestCase): - - def setUp(self): - super(ReadBatchFeaturesTest, self).setUp() - self._num_files = 2 - self._num_records = 7 - self.test_filenames = self._createFiles() - - def _read_batch_features(self, - filenames, - num_epochs, - batch_size, - reader_num_threads=1, - parser_num_threads=1, - shuffle=False, - shuffle_seed=None, - drop_final_batch=False): - self.filenames = filenames - self.num_epochs = num_epochs - self.batch_size = batch_size - - return readers.make_batched_features_dataset( - file_pattern=self.filenames, - batch_size=self.batch_size, - features={ - "file": parsing_ops.FixedLenFeature([], dtypes.int64), - "record": parsing_ops.FixedLenFeature([], dtypes.int64), - "keywords": parsing_ops.VarLenFeature(dtypes.string) - }, - reader=core_readers.TFRecordDataset, - num_epochs=self.num_epochs, - shuffle=shuffle, - shuffle_seed=shuffle_seed, - reader_num_threads=reader_num_threads, - parser_num_threads=parser_num_threads, - drop_final_batch=drop_final_batch).make_one_shot_iterator( - ).get_next() - - def _record(self, f, r): - example = example_pb2.Example( - features=feature_pb2.Features( - feature={ - "file": - 
feature_pb2.Feature( - int64_list=feature_pb2.Int64List(value=[f])), - "record": - feature_pb2.Feature( - int64_list=feature_pb2.Int64List(value=[r])), - "keywords": - feature_pb2.Feature( - bytes_list=feature_pb2.BytesList( - value=self._get_keywords(f, r))) - })) - return example.SerializeToString() - - def _get_keywords(self, f, r): - num_keywords = 1 + (f + r) % 2 - keywords = [] - for index in range(num_keywords): - keywords.append(compat.as_bytes("keyword%d" % index)) - return keywords - - def _createFiles(self): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) - filenames.append(fn) - writer = python_io.TFRecordWriter(fn) - for j in range(self._num_records): - writer.write(self._record(i, j)) - writer.close() - return filenames - - def _run_actual_batch(self, outputs, sess): - file_op = outputs["file"] - keywords_indices_op = outputs["keywords"].indices - keywords_values_op = outputs["keywords"].values - keywords_dense_shape_op = outputs["keywords"].dense_shape - record_op = outputs["record"] - return sess.run([ - file_op, keywords_indices_op, keywords_values_op, - keywords_dense_shape_op, record_op - ]) - - def _next_actual_batch(self, sess): - return self._run_actual_batch(self.outputs, sess) - - def _next_expected_batch(self, - file_indices, - batch_size, - num_epochs, - cycle_length=1): - - def _next_record(file_indices): - for j in file_indices: - for i in range(self._num_records): - yield j, i - - def _next_record_interleaved(file_indices, cycle_length): - return self._interleave([_next_record([i]) for i in file_indices], - cycle_length) - - file_batch = [] - keywords_batch_indices = [] - keywords_batch_values = [] - keywords_batch_max_len = 0 - record_batch = [] - batch_index = 0 - for _ in range(num_epochs): - if cycle_length == 1: - next_records = _next_record(file_indices) - else: - next_records = _next_record_interleaved(file_indices, cycle_length) - for record in next_records: - 
f = record[0] - r = record[1] - file_batch.append(f) - record_batch.append(r) - keywords = self._get_keywords(f, r) - keywords_batch_values.extend(keywords) - keywords_batch_indices.extend( - [[batch_index, i] for i in range(len(keywords))]) - batch_index += 1 - keywords_batch_max_len = max(keywords_batch_max_len, len(keywords)) - if len(file_batch) == batch_size: - yield [ - file_batch, keywords_batch_indices, keywords_batch_values, - [batch_size, keywords_batch_max_len], record_batch - ] - file_batch = [] - keywords_batch_indices = [] - keywords_batch_values = [] - keywords_batch_max_len = 0 - record_batch = [] - batch_index = 0 - if file_batch: - yield [ - file_batch, keywords_batch_indices, keywords_batch_values, - [len(file_batch), keywords_batch_max_len], record_batch - ] - - def _interleave(self, iterators, cycle_length): - pending_iterators = iterators - open_iterators = [] - num_open = 0 - for i in range(cycle_length): - if pending_iterators: - open_iterators.append(pending_iterators.pop(0)) - num_open += 1 - - while num_open: - for i in range(min(cycle_length, len(open_iterators))): - if open_iterators[i] is None: - continue - try: - yield next(open_iterators[i]) - except StopIteration: - if pending_iterators: - open_iterators[i] = pending_iterators.pop(0) - else: - open_iterators[i] = None - num_open -= 1 - - def _verify_records(self, - sess, - batch_size, - file_index=None, - num_epochs=1, - interleave_cycle_length=1): - if file_index is not None: - file_indices = [file_index] - else: - file_indices = range(self._num_files) - - for expected_batch in self._next_expected_batch( - file_indices, batch_size, num_epochs, interleave_cycle_length): - actual_batch = self._next_actual_batch(sess) - for i in range(len(expected_batch)): - self.assertAllEqual(expected_batch[i], actual_batch[i]) +def _interleave(iterators, cycle_length): + pending_iterators = iterators + open_iterators = [] + num_open = 0 + for i in range(cycle_length): + if pending_iterators: + 
open_iterators.append(pending_iterators.pop(0)) + num_open += 1 + + while num_open: + for i in range(min(cycle_length, len(open_iterators))): + if open_iterators[i] is None: + continue + try: + yield next(open_iterators[i]) + except StopIteration: + if pending_iterators: + open_iterators[i] = pending_iterators.pop(0) + else: + open_iterators[i] = None + num_open -= 1 + + +class ReadBatchFeaturesTest( + reader_dataset_ops_test_base.ReadBatchFeaturesTestBase): def testRead(self): for batch_size in [1, 2]: @@ -442,33 +288,33 @@ class ReadBatchFeaturesTest(test.TestCase): with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Basic test: read from file 0. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, - batch_size=batch_size) - self._verify_records(sess, batch_size, 0, num_epochs=num_epochs) + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records(sess, batch_size, 0, num_epochs=num_epochs) with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Basic test: read from file 1. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames[1], num_epochs=num_epochs, - batch_size=batch_size) - self._verify_records(sess, batch_size, 1, num_epochs=num_epochs) + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records(sess, batch_size, 1, num_epochs=num_epochs) with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Basic test: read from both files. 
- self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames, num_epochs=num_epochs, - batch_size=batch_size) - self._verify_records(sess, batch_size, num_epochs=num_epochs) + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records(sess, batch_size, num_epochs=num_epochs) with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) @@ -502,18 +348,18 @@ class ReadBatchFeaturesTest(test.TestCase): # Test that shuffling with same seed produces the same result. with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: - outputs1 = self._read_batch_features( + outputs1 = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, shuffle=True, - shuffle_seed=5) - outputs2 = self._read_batch_features( + shuffle_seed=5).make_one_shot_iterator().get_next() + outputs2 = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, shuffle=True, - shuffle_seed=5) + shuffle_seed=5).make_one_shot_iterator().get_next() for _ in range(total_records // batch_size): batch1 = self._run_actual_batch(outputs1, sess) batch2 = self._run_actual_batch(outputs2, sess) @@ -523,18 +369,18 @@ class ReadBatchFeaturesTest(test.TestCase): # Test that shuffling with different seeds produces a different order. 
with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: - outputs1 = self._read_batch_features( + outputs1 = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, shuffle=True, - shuffle_seed=5) - outputs2 = self._read_batch_features( + shuffle_seed=5).make_one_shot_iterator().get_next() + outputs2 = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, shuffle=True, - shuffle_seed=15) + shuffle_seed=15).make_one_shot_iterator().get_next() all_equal = True for _ in range(total_records // batch_size): batch1 = self._run_actual_batch(outputs1, sess) @@ -550,13 +396,14 @@ class ReadBatchFeaturesTest(test.TestCase): for parser_num_threads in [2, 4]: with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames, num_epochs=num_epochs, batch_size=batch_size, reader_num_threads=reader_num_threads, - parser_num_threads=parser_num_threads) - self._verify_records( + parser_num_threads=parser_num_threads).make_one_shot_iterator( + ).get_next() + self.verify_records( sess, batch_size, num_epochs=num_epochs, @@ -569,11 +416,11 @@ class ReadBatchFeaturesTest(test.TestCase): for num_epochs in [1, 10]: with ops.Graph().as_default(): # Basic test: read from file 0. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, - drop_final_batch=True) + drop_final_batch=True).make_one_shot_iterator().get_next() for _, tensor in self.outputs.items(): if isinstance(tensor, ops.Tensor): # Guard against SparseTensor. 
self.assertEqual(tensor.shape[0], batch_size) @@ -620,14 +467,12 @@ class MakeCsvDatasetTest(test.TestCase): f.close() return fn - def _create_file(self, fileno, header=True, comment=True): + def _create_file(self, fileno, header=True): rows = [] if header: rows.append(self.COLUMNS) for recno in range(self._num_records): rows.append(self._csv_values(fileno, recno)) - if comment: - rows.append("# Some comment goes here. Ignore me.") return self._write_file("csv_file%d.csv" % fileno, rows) def _create_files(self): @@ -648,9 +493,7 @@ class MakeCsvDatasetTest(test.TestCase): shuffle=False, shuffle_seed=None, header=True, - comment="#", na_value="", - default_float_type=dtypes.float32, ): return readers.make_csv_dataset( filenames, @@ -662,9 +505,7 @@ class MakeCsvDatasetTest(test.TestCase): shuffle=shuffle, shuffle_seed=shuffle_seed, header=header, - comment=comment, na_value=na_value, - default_float_type=default_float_type, select_columns=select_cols, ) @@ -786,29 +627,6 @@ class MakeCsvDatasetTest(test.TestCase): num_epochs=10, label_name=None) - def testMakeCSVDataset_withNoComments(self): - """Tests that datasets can be created from CSV files with no header line. - """ - defaults = self.DEFAULTS - file_without_header = self._create_file( - len(self._test_filenames), comment=False) - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - file_without_header, - defaults, - batch_size=2, - num_epochs=10, - comment=None, - ) - self._verify_records( - sess, - dataset, - [len(self._test_filenames)], - batch_size=2, - num_epochs=10, - ) - def testMakeCSVDataset_withNoHeader(self): """Tests that datasets can be created from CSV files with no header line. """ @@ -876,7 +694,7 @@ class MakeCsvDatasetTest(test.TestCase): In that case, we should infer the types from the first N records. 
""" - # Test that it works with standard test files (with comments, header, etc) + # Test that it works with standard test files (with header, etc) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = self._make_csv_dataset( @@ -889,7 +707,9 @@ class MakeCsvDatasetTest(test.TestCase): num_epochs=10, defaults=[[], [], [], [], [""]]) - # Test on a deliberately tricky file + def testMakeCSVDataset_withTypeInferenceTricky(self): + # Test on a deliberately tricky file (type changes as we read more rows, and + # there are null values) fn = os.path.join(self.get_temp_dir(), "file.csv") expected_dtypes = [ dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float32, @@ -914,20 +734,29 @@ class MakeCsvDatasetTest(test.TestCase): column_names=None, label_name=None, na_value="NAN", - default_float_type=dtypes.float32, ) features = dataset.make_one_shot_iterator().get_next() # Check that types match for i in range(len(expected_dtypes)): + print(features["col%d" % i].dtype, expected_dtypes[i]) assert features["col%d" % i].dtype == expected_dtypes[i] for i in range(len(rows)): assert sess.run(features) == dict(zip(col_names, expected[i])) - # With float64 as default type for floats + def testMakeCSVDataset_withTypeInferenceAllTypes(self): + # Test that we make the correct inference for all types with fallthrough + fn = os.path.join(self.get_temp_dir(), "file.csv") expected_dtypes = [ - dtypes.int32, dtypes.int64, dtypes.float64, dtypes.float64, + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string, dtypes.string ] + col_names = ["col%d" % i for i in range(len(expected_dtypes))] + rows = [[1, 2**31 + 1, 1.0, 4e40, "abc", ""]] + expected = [[ + 1, 2**31 + 1, 1.0, 4e40, "abc".encode("utf-8"), "".encode("utf-8") + ]] + self._write_file("file.csv", [col_names] + rows) + with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = self._make_csv_dataset( @@ -936,7 +765,6 @@ class 
MakeCsvDatasetTest(test.TestCase): column_names=None, label_name=None, na_value="NAN", - default_float_type=dtypes.float64, ) features = dataset.make_one_shot_iterator().get_next() # Check that types match @@ -1086,5 +914,189 @@ class MakeCsvDatasetTest(test.TestCase): self.assertFalse(all_equal) +class MakeTFRecordDatasetTest(TFRecordDatasetTestBase): + + def _next_expected_batch(self, + file_indices, + batch_size, + num_epochs, + cycle_length, + drop_final_batch, + use_parser_fn): + + def _next_record(file_indices): + for j in file_indices: + for i in range(self._num_records): + yield j, i + + def _next_record_interleaved(file_indices, cycle_length): + return _interleave([_next_record([i]) for i in file_indices], + cycle_length) + + record_batch = [] + batch_index = 0 + for _ in range(num_epochs): + if cycle_length == 1: + next_records = _next_record(file_indices) + else: + next_records = _next_record_interleaved(file_indices, cycle_length) + for f, r in next_records: + record = self._record(f, r) + if use_parser_fn: + record = record[1:] + record_batch.append(record) + batch_index += 1 + if len(record_batch) == batch_size: + yield record_batch + record_batch = [] + batch_index = 0 + if record_batch and not drop_final_batch: + yield record_batch + + def _verify_records(self, + sess, + outputs, + batch_size, + file_index, + num_epochs, + interleave_cycle_length, + drop_final_batch, + use_parser_fn): + if file_index is not None: + file_indices = [file_index] + else: + file_indices = range(self._num_files) + + for expected_batch in self._next_expected_batch( + file_indices, batch_size, num_epochs, interleave_cycle_length, + drop_final_batch, use_parser_fn): + actual_batch = sess.run(outputs) + self.assertAllEqual(expected_batch, actual_batch) + + def _read_test(self, batch_size, num_epochs, file_index=None, + num_parallel_reads=1, drop_final_batch=False, parser_fn=False): + if file_index is None: + file_pattern = self.test_filenames + else: + file_pattern = 
self.test_filenames[file_index] + + if parser_fn: + fn = lambda x: string_ops.substr(x, 1, 999) + else: + fn = None + + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + outputs = readers.make_tf_record_dataset( + file_pattern=file_pattern, + num_epochs=num_epochs, + batch_size=batch_size, + parser_fn=fn, + num_parallel_reads=num_parallel_reads, + drop_final_batch=drop_final_batch, + shuffle=False).make_one_shot_iterator().get_next() + self._verify_records( + sess, outputs, batch_size, file_index, num_epochs=num_epochs, + interleave_cycle_length=num_parallel_reads, + drop_final_batch=drop_final_batch, use_parser_fn=parser_fn) + with self.assertRaises(errors.OutOfRangeError): + sess.run(outputs) + + def testRead(self): + for batch_size in [1, 2]: + for num_epochs in [1, 3]: + # Basic test: read from file 0. + self._read_test(batch_size, num_epochs, 0) + + # Basic test: read from file 1. + self._read_test(batch_size, num_epochs, 1) + + # Basic test: read from both files. + self._read_test(batch_size, num_epochs) + + # Basic test: read from both files, with parallel reads. + self._read_test(batch_size, num_epochs, num_parallel_reads=8) + + def testDropFinalBatch(self): + for batch_size in [1, 2, 10]: + for num_epochs in [1, 3]: + # Read from file 0. + self._read_test(batch_size, num_epochs, 0, drop_final_batch=True) + + # Read from both files. + self._read_test(batch_size, num_epochs, drop_final_batch=True) + + # Read from both files, with parallel reads. 
+ self._read_test(batch_size, num_epochs, num_parallel_reads=8, + drop_final_batch=True) + + def testParserFn(self): + for batch_size in [1, 2]: + for num_epochs in [1, 3]: + for drop_final_batch in [False, True]: + self._read_test(batch_size, num_epochs, parser_fn=True, + drop_final_batch=drop_final_batch) + self._read_test(batch_size, num_epochs, num_parallel_reads=8, + parser_fn=True, drop_final_batch=drop_final_batch) + + def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1, + seed=None): + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.make_tf_record_dataset( + file_pattern=self.test_filenames, + num_epochs=num_epochs, + batch_size=batch_size, + num_parallel_reads=num_parallel_reads, + shuffle=True, + shuffle_seed=seed) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + sess.run(iterator.initializer) + first_batches = [] + try: + while True: + first_batches.append(sess.run(next_element)) + except errors.OutOfRangeError: + pass + + sess.run(iterator.initializer) + second_batches = [] + try: + while True: + second_batches.append(sess.run(next_element)) + except errors.OutOfRangeError: + pass + + self.assertEqual(len(first_batches), len(second_batches)) + if seed is not None: + # if you set a seed, should get the same results + for i in range(len(first_batches)): + self.assertAllEqual(first_batches[i], second_batches[i]) + + expected = [] + for f in range(self._num_files): + for r in range(self._num_records): + expected.extend([self._record(f, r)] * num_epochs) + + for batches in (first_batches, second_batches): + actual = [] + for b in batches: + actual.extend(b) + self.assertAllEqual(sorted(expected), sorted(actual)) + + def testShuffle(self): + for batch_size in [1, 2]: + for num_epochs in [1, 3]: + for num_parallel_reads in [1, 2]: + # Test that all expected elements are produced + self._shuffle_test(batch_size, num_epochs, num_parallel_reads) + # Test 
that elements are produced in a consistent order if + # you specify a seed. + self._shuffle_test(batch_size, num_epochs, num_parallel_reads, + seed=21345) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..805a7c7b7384d53cc166a48ba243502ef8643280 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py @@ -0,0 +1,218 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.ops import readers +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.framework import dtypes +from tensorflow.python.lib.io import python_io +from tensorflow.python.ops import parsing_ops +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class ReadBatchFeaturesTestBase(test.TestCase): + """Base class for setting up and testing `make_batched_features_dataset`.""" + + def setUp(self): + super(ReadBatchFeaturesTestBase, self).setUp() + self._num_files = 2 + self._num_records = 7 + self.test_filenames = self._createFiles() + + def make_batch_feature(self, + filenames, + num_epochs, + batch_size, + reader_num_threads=1, + parser_num_threads=1, + shuffle=False, + shuffle_seed=None, + drop_final_batch=False): + self.filenames = filenames + self.num_epochs = num_epochs + self.batch_size = batch_size + + return readers.make_batched_features_dataset( + file_pattern=self.filenames, + batch_size=self.batch_size, + features={ + "file": parsing_ops.FixedLenFeature([], dtypes.int64), + "record": parsing_ops.FixedLenFeature([], dtypes.int64), + "keywords": parsing_ops.VarLenFeature(dtypes.string) + }, + reader=core_readers.TFRecordDataset, + num_epochs=self.num_epochs, + shuffle=shuffle, + shuffle_seed=shuffle_seed, + reader_num_threads=reader_num_threads, + parser_num_threads=parser_num_threads, + drop_final_batch=drop_final_batch) + + def _record(self, f, r): + example = example_pb2.Example( + features=feature_pb2.Features( + feature={ + "file": + feature_pb2.Feature( + 
int64_list=feature_pb2.Int64List(value=[f])), + "record": + feature_pb2.Feature( + int64_list=feature_pb2.Int64List(value=[r])), + "keywords": + feature_pb2.Feature( + bytes_list=feature_pb2.BytesList( + value=self._get_keywords(f, r))) + })) + return example.SerializeToString() + + def _get_keywords(self, f, r): + num_keywords = 1 + (f + r) % 2 + keywords = [] + for index in range(num_keywords): + keywords.append(compat.as_bytes("keyword%d" % index)) + return keywords + + def _sum_keywords(self, num_files): + sum_keywords = 0 + for i in range(num_files): + for j in range(self._num_records): + sum_keywords += 1 + (i + j) % 2 + return sum_keywords + + def _createFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) + filenames.append(fn) + writer = python_io.TFRecordWriter(fn) + for j in range(self._num_records): + writer.write(self._record(i, j)) + writer.close() + return filenames + + def _run_actual_batch(self, outputs, sess): + file_op = outputs["file"] + keywords_indices_op = outputs["keywords"].indices + keywords_values_op = outputs["keywords"].values + keywords_dense_shape_op = outputs["keywords"].dense_shape + record_op = outputs["record"] + return sess.run([ + file_op, keywords_indices_op, keywords_values_op, + keywords_dense_shape_op, record_op + ]) + + def _next_actual_batch(self, sess): + return self._run_actual_batch(self.outputs, sess) + + def _interleave(self, iterators, cycle_length): + pending_iterators = iterators + open_iterators = [] + num_open = 0 + for i in range(cycle_length): + if pending_iterators: + open_iterators.append(pending_iterators.pop(0)) + num_open += 1 + + while num_open: + for i in range(min(cycle_length, len(open_iterators))): + if open_iterators[i] is None: + continue + try: + yield next(open_iterators[i]) + except StopIteration: + if pending_iterators: + open_iterators[i] = pending_iterators.pop(0) + else: + open_iterators[i] = None + num_open -= 1 + + 
def _next_expected_batch(self, + file_indices, + batch_size, + num_epochs, + cycle_length=1): + + def _next_record(file_indices): + for j in file_indices: + for i in range(self._num_records): + yield j, i + + def _next_record_interleaved(file_indices, cycle_length): + return self._interleave([_next_record([i]) for i in file_indices], + cycle_length) + + file_batch = [] + keywords_batch_indices = [] + keywords_batch_values = [] + keywords_batch_max_len = 0 + record_batch = [] + batch_index = 0 + for _ in range(num_epochs): + if cycle_length == 1: + next_records = _next_record(file_indices) + else: + next_records = _next_record_interleaved(file_indices, cycle_length) + for record in next_records: + f = record[0] + r = record[1] + file_batch.append(f) + record_batch.append(r) + keywords = self._get_keywords(f, r) + keywords_batch_values.extend(keywords) + keywords_batch_indices.extend( + [[batch_index, i] for i in range(len(keywords))]) + batch_index += 1 + keywords_batch_max_len = max(keywords_batch_max_len, len(keywords)) + if len(file_batch) == batch_size: + yield [ + file_batch, keywords_batch_indices, keywords_batch_values, + [batch_size, keywords_batch_max_len], record_batch + ] + file_batch = [] + keywords_batch_indices = [] + keywords_batch_values = [] + keywords_batch_max_len = 0 + record_batch = [] + batch_index = 0 + if file_batch: + yield [ + file_batch, keywords_batch_indices, keywords_batch_values, + [len(file_batch), keywords_batch_max_len], record_batch + ] + + def verify_records(self, + sess, + batch_size, + file_index=None, + num_epochs=1, + interleave_cycle_length=1): + if file_index is not None: + file_indices = [file_index] + else: + file_indices = range(self._num_files) + + for expected_batch in self._next_expected_batch( + file_indices, batch_size, num_epochs, interleave_cycle_length): + actual_batch = self._next_actual_batch(sess) + for i in range(len(expected_batch)): + self.assertAllEqual(expected_batch[i], actual_batch[i]) diff --git 
a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index bcc644c0971854d948025009dc7add2fea214048..1b67a33f04b0c2ac80402e163005123a4b3e4400 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -20,11 +20,13 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib class ShuffleDatasetSerializationTest( @@ -50,26 +52,100 @@ class ShuffleDatasetSerializationTest( num_repeats = 5 num_outputs = range_limit * num_repeats buffer_sizes = [1, 3, 8, 10, 25, 50] - reshuffle_each_iteration = False # pylint: disable=cell-var-from-loop # pylint: disable=g-long-lambda - for buffer_size in buffer_sizes: - self.run_core_tests( - lambda: self._build_shuffle_dataset( + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + self.run_core_tests( + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration), + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=10, + reshuffle_each_iteration=reshuffle_each_iteration), + num_outputs) + # pylint: enable=cell-var-from-loop + # pylint: enable=g-long-lambda + + def testNonDeterministicSeeding(self): + + range_limit = 10 + num_repeats = 5 + 
num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 8, 10, 25, 50] + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + + def ds_fn(): + # pylint: disable=cell-var-from-loop + return self._build_shuffle_dataset( range_limit=range_limit, num_repeats=num_repeats, buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration), - lambda: self._build_shuffle_dataset( + seed=None, # Iterator seeds are generated non-deterministically. + reshuffle_each_iteration=reshuffle_each_iteration) + # pylint: enable=cell-var-from-loop + + # We checkpoint the initial state of the Dataset so that we can restore + # the seeds in the next run. Since the seeding is non-deterministic + # the dataset gets initialized with different seeds each time. + expected = self.gen_outputs( + ds_fn, + break_points=[0], + num_outputs=num_outputs, + ckpt_saved=False, + verify_exhausted=False, + save_checkpoint_at_end=False) + actual = self.gen_outputs( + ds_fn, + break_points=self.gen_break_points(num_outputs), + num_outputs=num_outputs, + ckpt_saved=True, + verify_exhausted=False) + self.match(expected, actual) + + def testMultipleIterators(self): + range_limit = 10 + num_repeats = 5 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 8, 10, 25, 50] + + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + + def ds_fn(): + # pylint: disable=cell-var-from-loop + return self._build_shuffle_dataset( range_limit=range_limit, num_repeats=num_repeats, buffer_size=buffer_size, - seed=10, - reshuffle_each_iteration=reshuffle_each_iteration), - num_outputs) - # pylint: enable=cell-var-from-loop - # pylint: enable=g-long-lambda + seed=None, # Iterator seeds are generated non-deterministically. 
+ reshuffle_each_iteration=reshuffle_each_iteration) + # pylint: enable=cell-var-from-loop + + with ops.Graph().as_default() as g: + ds = ds_fn() + iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()] + get_next_ops = [it.get_next() for it in iterators] + saveables = [ + contrib_iterator_ops.make_saveable_from_iterator(it) + for it in iterators + ] + for saveable in saveables: + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + saver = saver_lib.Saver(allow_empty=True) + with self.test_session(graph=g) as sess: + self._save(sess, saver) + expected = [sess.run(get_next_ops) for _ in range(num_outputs)] + self._restore(saver, sess) + actual = [sess.run(get_next_ops) for _ in range(num_outputs)] + self.match(expected, actual) class ShuffleAndRepeatTest( diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py index e26cef8ec522c7e69a0c19b2b30a969bbfc0ad78..4148addf2878c99f47ebe1454edf69ad7f38dfbc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py @@ -22,6 +22,7 @@ import os import sqlite3 +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import readers from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -29,7 +30,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class SqlDatasetTest(test.TestCase): +class SqlDatasetTestBase(test.TestCase): def _createSqlDataset(self, output_types, num_repeats=1): dataset = readers.SqlDataset(self.driver_name, self.data_source_name, @@ -92,6 +93,9 @@ class SqlDatasetTest(test.TestCase): conn.commit() conn.close() + +class SqlDatasetTest(SqlDatasetTestBase): + # Test that SqlDataset can read from a database table. 
def testReadResultSet(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, @@ -652,5 +656,27 @@ class SqlDatasetTest(test.TestCase): sess.run(get_next) +class SqlDatasetSerializationTest( + SqlDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, num_repeats): + data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") + driver_name = array_ops.placeholder_with_default( + array_ops.constant("sqlite", dtypes.string), shape=[]) + query = ("SELECT first_name, last_name, motto FROM students ORDER BY " + "first_name DESC") + output_types = (dtypes.string, dtypes.string, dtypes.string) + return readers.SqlDataset(driver_name, data_source_name, query, + output_types).repeat(num_repeats) + + def testSQLSaveable(self): + num_repeats = 4 + num_outputs = num_repeats * 2 + self.run_core_tests(lambda: self._build_dataset(num_repeats), + lambda: self._build_dataset(num_repeats // 2), + num_outputs) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index 5c74ed6ae7210e8e22efb6e8fdb773397459ce1e..17b6644759e53f84b23e070a71267aa15dcffe49 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.core.framework import summary_pb2 from tensorflow.python.data.ops import dataset_ops @@ -29,7 +30,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class StatsDatasetTest(test.TestCase): 
+class StatsDatasetTestBase(test.TestCase): def _assertSummaryHasCount(self, summary_str, tag, expected_value): summary_proto = summary_pb2.Summary() @@ -49,6 +50,9 @@ class StatsDatasetTest(test.TestCase): return self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + +class StatsDatasetTest(StatsDatasetTestBase): + def testBytesProduced(self): stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).map( @@ -193,6 +197,45 @@ class StatsDatasetTest(test.TestCase): self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) +class FeatureStatsDatasetTest( + StatsDatasetTestBase, + reader_dataset_ops_test_base.ReadBatchFeaturesTestBase): + + def testFeaturesStats(self): + num_epochs = 5 + total_records = num_epochs * self._num_records + batch_size = 2 + stats_aggregator = stats_ops.StatsAggregator() + dataset = self.make_batch_feature( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5, + drop_final_batch=True).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for _ in range(total_records // batch_size): + sess.run(next_element) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount( + sess.run(summary_t), "record_stats:features", total_records) + self._assertSummaryHasCount( + sess.run(summary_t), "record_stats:feature-values", total_records) + self._assertSummaryHasSum( + sess.run(summary_t), "record_stats:features", total_records * 3) + self._assertSummaryHasSum( + sess.run(summary_t), "record_stats:feature-values", + self._sum_keywords(1) * num_epochs + 2 * total_records) + + class StatsDatasetSerializationTest( 
dataset_serialization_test_base.DatasetSerializationTestBase): diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 7a3e42cc72755c67b910db99c0238f6ba780a942..33b7a75046cf2acfa3d787833b907aa2b28dbdca 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -45,6 +45,27 @@ py_library( "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:framework_ops", "//tensorflow/python:training", + "//tensorflow/python/data/ops:iterator_ops", + ], +) + +py_test( + name = "iterator_ops_test", + size = "small", + srcs = ["iterator_ops_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":iterator_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:model_fn", ], ) @@ -75,8 +96,10 @@ py_library( srcs_version = "PY2AND3", deps = [ ":batching", + ":gen_dataset_ops", ":interleave_ops", ":shuffle_ops", + ":stats_ops", "//tensorflow/python:constant_op", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", @@ -85,12 +108,12 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", - "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", + "//tensorflow/python/data/util:convert", "//tensorflow/python/data/util:nest", "//third_party/py/numpy", ], @@ -121,6 +144,7 @@ py_library( "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:convert", 
"//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", ], @@ -187,6 +211,20 @@ py_library( ], ) +py_library( + name = "optimization", + srcs = ["optimization.py"], + srcs_version = "PY2AND3", + deps = [ + ":contrib_op_loader", + ":gen_dataset_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + py_library( name = "resampling", srcs = ["resampling.py"], @@ -347,6 +385,7 @@ py_library( ":get_single_element", ":grouping", ":interleave_ops", + ":optimization", ":prefetching_ops", ":readers", ":resampling", diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index b9393de4e90ae2597045b29070934b94e18cfcbd..17256eb97211c294a6e31c7fd2e969d71e091fb3 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.contrib.framework import with_shape from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes @@ -29,6 +30,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util import deprecation def dense_to_sparse_batch(batch_size, row_shape): @@ -218,6 +220,8 @@ def filter_irregular_batches(batch_size): return _apply_fn +@deprecation.deprecated( + None, "Use `tf.data.Dataset.batch(..., drop_remainder=True)`.") def batch_and_drop_remainder(batch_size): """A batching transformation that omits the final small batch (if present). 
@@ -250,12 +254,16 @@ def batch_and_drop_remainder(batch_size): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" + # TODO(jsimsa): Switch to using `batch(..., drop_remainder=True)` any time + # after 6/30/2018. batched = dataset.batch(batch_size) return filter_irregular_batches(batch_size)(batched) return _apply_fn +@deprecation.deprecated( + None, "Use `tf.data.Dataset.padded_batch(..., drop_remainder=True)`.") def padded_batch_and_drop_remainder(batch_size, padded_shapes, padding_values=None): @@ -284,6 +292,8 @@ def padded_batch_and_drop_remainder(batch_size, def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" + # TODO(jsimsa): Switch to using `padded_batch(..., drop_remainder=True)` + # any time after 6/30/2018. batched = dataset.padded_batch( batch_size, padded_shapes=padded_shapes, padding_values=padding_values) return filter_irregular_batches(batch_size)(batched) @@ -309,7 +319,7 @@ class DenseToSparseBatchDataset(dataset_ops.Dataset): return gen_dataset_ops.dense_to_sparse_batch_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._batch_size, - row_shape=dataset_ops._partial_shape_to_tensor(self._row_shape), # pylint: disable=protected-access + row_shape=convert.partial_shape_to_tensor(self._row_shape), output_shapes=nest.flatten( sparse.as_dense_shapes(self.output_shapes, self.output_classes)), output_types=nest.flatten( diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index ea229b5b27b117984e508fa4edc6f1cf713008b4..520f78422839af638ccc855fded9fd436d4a8b49 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -300,6 +300,7 @@ class GroupByReducerDataset(dataset_ops.Dataset): raise ValueError( "`key_func` must return a single tf.int64 tensor. 
" "Got type=%s and shape=%s" % (ret.dtype, ret.get_shape())) + dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access return ret self._key_func = tf_key_func @@ -327,6 +328,8 @@ class GroupByReducerDataset(dataset_ops.Dataset): self._state_types = nest.pack_sequence_as( ret, [t.dtype for t in nest.flatten(ret)]) + dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access + # Serialize any sparse tensors. ret = nest.pack_sequence_as( ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) @@ -398,6 +401,8 @@ class GroupByReducerDataset(dataset_ops.Dataset): nest.pack_sequence_as(self._state_types, [t.dtype for t in flat_new_state]))) + dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access + # Serialize any sparse tensors. ret = nest.pack_sequence_as( ret, @@ -464,6 +469,8 @@ class GroupByReducerDataset(dataset_ops.Dataset): self._output_types = nest.pack_sequence_as( ret, [t.dtype for t in nest.flatten(ret)]) + dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access + # Serialize any sparse tensors. 
ret = nest.pack_sequence_as( ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) @@ -525,6 +532,7 @@ class GroupByWindowDataset(dataset_ops.Dataset): if window_size.dtype != dtypes.int64: raise ValueError( "`window_size_func` must return a single tf.int64 tensor.") + dataset_ops._warn_if_collections("tf.contrib.data.group_by_window()") # pylint: disable=protected-access return window_size self._window_size_func = tf_window_size_func @@ -557,6 +565,7 @@ class GroupByWindowDataset(dataset_ops.Dataset): ret = ops.convert_to_tensor(ret, dtype=dtypes.int64) if ret.dtype != dtypes.int64: raise ValueError("`key_func` must return a single tf.int64 tensor.") + dataset_ops._warn_if_collections("tf.contrib.data.group_by_window()") # pylint: disable=protected-access return ret self._key_func = tf_key_func @@ -580,6 +589,7 @@ class GroupByWindowDataset(dataset_ops.Dataset): self._output_classes = output_dataset.output_classes self._output_types = output_dataset.output_types self._output_shapes = output_dataset.output_shapes + dataset_ops._warn_if_collections("tf.contrib.data.group_by_window()") # pylint: disable=protected-access return output_dataset._as_variant_tensor() # pylint: disable=protected-access self._reduce_func = tf_reduce_func diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 812a50ecbf105393f7e422edbbdf5c87311d72c1..be66fbac50753c8f54b62dd615ee60804f4cf20d 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -27,6 +27,7 @@ from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.util import deprecation @@ -240,3 
+241,47 @@ def sample_from_datasets(datasets, weights=None, seed=None): (logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset) return DirectedInterleaveDataset(selector_input, datasets) + + +def choose_from_datasets(datasets, choice_dataset): + """Creates a dataset that deterministically chooses elements from `datasets`. + + For example, given the following datasets: + + ```python + datasets = [tf.data.Dataset.from_tensors("foo").repeat(), + tf.data.Dataset.from_tensors("bar").repeat(), + tf.data.Dataset.from_tensors("baz").repeat()] + + # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`. + choice_dataset = tf.data.Dataset.range(3).repeat(3) + + result = tf.contrib.data.choose_from_datasets(datasets, choice_dataset) + ``` + + The elements of `result` will be: + + ``` + "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz" + ``` + + Args: + datasets: A list of @{tf.data.Dataset} objects with compatible structure. + choice_dataset: A @{tf.data.Dataset} of scalar `tf.int64` tensors between + `0` and `len(datasets) - 1`. + + Returns: + A dataset that interleaves elements from `datasets` according to the values + of `choice_dataset`. + + Raises: + TypeError: If the `datasets` or `choice_dataset` arguments have the wrong + type. 
+ """ + if not (choice_dataset.output_types == dtypes.int64 + and choice_dataset.output_shapes.is_compatible_with( + tensor_shape.scalar()) + and choice_dataset.output_classes == ops.Tensor): + raise TypeError("`choice_dataset` must be a dataset of scalar " + "`tf.int64` tensors.") + return DirectedInterleaveDataset(choice_dataset, datasets) diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py index d736029fb035e573b70e8b19570e4e8ceca3c005..0d71be66018eeebe60de9deff24ceb6854d209d9 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops.py +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -16,10 +16,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.training import saver +from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import session_run_hook def make_saveable_from_iterator(iterator): @@ -60,14 +62,14 @@ def make_saveable_from_iterator(iterator): return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access -class _Saveable(saver.BaseSaverBuilder.SaveableObject): +class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject): """SaveableObject for saving/restoring iterator state.""" def __init__(self, iterator_resource): serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource) specs = [ - saver.BaseSaverBuilder.SaveSpec(serialized_iterator, "", - iterator_resource.name + "-state") + saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "", + iterator_resource.name + "-state") ] super(_Saveable, self).__init__(iterator_resource, specs, iterator_resource.name) @@ -75,3 +77,182 @@ class 
_Saveable(saver.BaseSaverBuilder.SaveableObject): def restore(self, restored_tensors, unused_restored_shapes): with ops.colocate_with(self.op): return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0]) + + +class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): + """Checkpoints input pipeline state every N steps or seconds. + + This hook saves the state of the iterators in the `Graph` so that when + training is resumed the input pipeline continues from where it left off. + This could potentially avoid overfitting in certain pipelines where the + number of training steps per eval is small compared to the dataset + size or if the training pipeline is pre-empted. + + Differences from `CheckpointSaverHook`: + 1. Saves only the input pipelines in the "iterators" collection and not the + global variables or other saveable objects. + 2. Does not write the `GraphDef` and `MetaGraphDef` to the summary. + + Example of checkpointing the training pipeline: + + ```python + est = tf.estimator.Estimator(model_fn) + while True: + est.train( + train_input_fn, + hooks=[tf.contrib.data.CheckpointInputPipelineHook(est)], + steps=train_steps_per_eval) + # Note: We do not pass the hook here. + metrics = est.evaluate(eval_input_fn) + if should_stop_the_training(metrics): + break + ``` + + This hook should be used if the input pipeline state needs to be saved + separately from the model checkpoint. Doing so may be useful for a few reasons: + 1. The input pipeline checkpoint may be large, if there are large shuffle + or prefetch buffers for instance, and may bloat the checkpoint size. + 2. If the input pipeline is shared between training and validation, restoring + the checkpoint during validation may override the validation input + pipeline. + + For saving the input pipeline checkpoint alongside the model weights use + @{tf.contrib.data.make_saveable_from_iterator} directly to create a + `SaveableObject` and add it to the `SAVEABLE_OBJECTS` collection.
Note, however, + that you will need to be careful not to restore the training iterator during + eval. You can do that by not adding the iterator to the `SAVEABLE_OBJECTS` + collection when building the eval graph. + """ + + def __init__(self, estimator): + """Initializes a `CheckpointInputPipelineHook`. + + Args: + estimator: Estimator. + + Raises: + ValueError: One of `save_steps` or `save_secs` should be set. + ValueError: At most one of saver or scaffold should be set. + """ + # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or + # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines. + # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is + # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix + # to be different to avoid conflicts with the model checkpoint. + + # pylint: disable=protected-access + checkpoint_prefix = "input" + if estimator._config.num_worker_replicas > 1: + # Distributed setting. + suffix = "_{}_{}".format(estimator._config.task_type, + estimator._config.task_id) + checkpoint_prefix += suffix + # pylint: enable=protected-access + + # We use a composition paradigm instead of inheriting from + # `CheckpointSaverHook` because `Estimator` does an `isinstance` check + # to check whether a `CheckpointSaverHook` is already present in the list + # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook` + # would thwart this behavior. This hook checkpoints *only the iterators* + # and not the graph variables. + self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook( + estimator.model_dir, + save_secs=estimator._config.save_checkpoints_secs, # pylint: disable=protected-access + save_steps=estimator._config.save_checkpoints_steps, # pylint: disable=protected-access + checkpoint_basename=checkpoint_prefix + ".ckpt") + + # Name for the protocol buffer file that will contain the list of most + # recent checkpoints stored as a `CheckpointState` protocol buffer.
+ # This file, kept in the same directory as the checkpoint files, is + # automatically managed by the `Saver` to keep track of recent checkpoints. + # The default name used by the `Saver` for this file is "checkpoint". Here + # we use the name "checkpoint_<checkpoint_prefix>" so that in case the + # `checkpoint_dir` is the same as the model checkpoint directory, there are + # no conflicts during restore. + self._latest_filename = "checkpoint_" + checkpoint_prefix + self._first_run = True + + def begin(self): + # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS` + # collection if no `Saver` or `Scaffold` is provided. + # pylint: disable=protected-access + if (self._checkpoint_saver_hook._saver is None and + self._checkpoint_saver_hook._scaffold is None): + iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS) + saveables = [_Saveable(i) for i in iterators] + self._checkpoint_saver_hook._saver = _CustomSaver(saveables, + self._latest_filename) + # pylint: enable=protected-access + self._checkpoint_saver_hook.begin() + + def _restore_or_save_initial_ckpt(self, session): + # Ideally this should be run in after_create_session but is not for the + # following reason: + # Currently there is no way of enforcing an order of running the + # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook` + # is run *after* this hook. That is troublesome because + # 1. If a checkpoint exists and this hook restores it, the initializer hook + # will override it. + # 2. If no checkpoint exists, this hook will try to save an uninitialized + # iterator which will result in an exception. + # + # As a temporary fix we enter the following implicit contract between this + # hook and the _DatasetInitializerHook. + # 1. The _DatasetInitializerHook initializes the iterator in the call to + # after_create_session. + # 2.
This hook saves the iterator on the first call to `before_run()`, which + # is guaranteed to happen after `after_create_session()` of all hooks + # have been run. + + # Check if there is an existing checkpoint. If so, restore from it. + # pylint: disable=protected-access + latest_checkpoint_path = saver_lib.latest_checkpoint( + self._checkpoint_saver_hook._checkpoint_dir, + latest_filename=self._latest_filename) + if latest_checkpoint_path: + self._checkpoint_saver_hook._get_saver().restore(session, + latest_checkpoint_path) + else: + # The checkpoint saved here is the state at step "global_step". + # Note: We do not save the GraphDef or MetaGraphDef here. + global_step = session.run(self._checkpoint_saver_hook._global_step_tensor) + self._checkpoint_saver_hook._save(session, global_step) + self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step) + # pylint: enable=protected-access + + def before_run(self, run_context): + if self._first_run: + self._restore_or_save_initial_ckpt(run_context.session) + self._first_run = False + return self._checkpoint_saver_hook.before_run(run_context) + + def after_run(self, run_context, run_values): + self._checkpoint_saver_hook.after_run(run_context, run_values) + + def end(self, session): + self._checkpoint_saver_hook.end(session) + + +class _CustomSaver(saver_lib.Saver): + """`Saver` with a different default `latest_filename`. + + This is used in the `CheckpointInputPipelineHook` to avoid conflicts with + the model ckpt saved by the `CheckpointSaverHook`. 
+ """ + + def __init__(self, var_list, latest_filename): + super(_CustomSaver, self).__init__(var_list) + self._latest_filename = latest_filename + + def save(self, + sess, + save_path, + global_step=None, + latest_filename=None, + meta_graph_suffix="meta", + write_meta_graph=True, + write_state=True, + strip_default_attrs=False): + return super(_CustomSaver, self).save( + sess, save_path, global_step, latest_filename or self._latest_filename, + meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs) diff --git a/tensorflow/contrib/data/python/ops/iterator_ops_test.py b/tensorflow/contrib/data/python/ops/iterator_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..30a993b1f7056b9726f524b2279131339c80c5eb --- /dev/null +++ b/tensorflow/contrib/data/python/ops/iterator_ops_test.py @@ -0,0 +1,123 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for experimental iterator_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import iterator_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import training_util + + +class CheckpointInputPipelineHookTest(test.TestCase): + + @staticmethod + def _model_fn(features, labels, mode, config): + del labels + del mode + del config + global_step = training_util.get_or_create_global_step() + update_global_step_op = global_step.assign_add(1) + latest_feature = variables.Variable( + 0, name='latest_feature', dtype=dtypes.int64) + store_latest_feature_op = latest_feature.assign(features) + ops.add_to_collection('my_vars', global_step) + ops.add_to_collection('my_vars', latest_feature) + return model_fn.EstimatorSpec( + mode='train', + train_op=control_flow_ops.group( + [update_global_step_op, store_latest_feature_op]), + loss=constant_op.constant(2.0)) + + def _read_vars(self, model_dir): + """Returns (global_step, latest_feature).""" + with ops.Graph().as_default() as g: + ckpt_path = saver_lib.latest_checkpoint(model_dir) + meta_filename = ckpt_path + '.meta' + saver_lib.import_meta_graph(meta_filename) + saver = saver_lib.Saver() + with self.test_session(graph=g) as sess: + saver.restore(sess, ckpt_path) + return sess.run(ops.get_collection('my_vars')) + + def _build_iterator_saver_hook(self, 
est): + return iterator_ops.CheckpointInputPipelineHook(est) + + def testReturnDatasetFromInputFn(self): + + def _input_fn(): + return dataset_ops.Dataset.range(10) + + est = estimator.Estimator(model_fn=self._model_fn) + + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1)) + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3)) + + def testBuildIteratorInInputFn(self): + + def _input_fn(): + ds = dataset_ops.Dataset.range(10) + iterator = ds.make_one_shot_iterator() + return iterator.get_next() + + est = estimator.Estimator(model_fn=self._model_fn) + + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1)) + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3)) + + def testDoNotRestore(self): + + def _input_fn(): + return dataset_ops.Dataset.range(10) + + est = estimator.Estimator(model_fn=self._model_fn) + + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1)) + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3)) + # Hook not provided, input pipeline was not restored. 
+ est.train(_input_fn, steps=2) + self.assertSequenceEqual(self._read_vars(est.model_dir), (6, 1)) + + def testRaiseErrorIfNoIterator(self): + + def _input_fn(): + return constant_op.constant(1, dtype=dtypes.int64) + + est = estimator.Estimator(model_fn=self._model_fn) + + with self.assertRaises(ValueError): + est.train( + _input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..cad41bce2961f29a7591fe3d382d1ab35a6b38b4 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/optimization.py @@ -0,0 +1,80 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Experimental API for optimizing `tf.data` pipelines.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops + + +def optimize(optimizations=None): + """A transformation that applies optimizations. + + Args: + optimizations: (Optional.) A `tf.string` vector `tf.Tensor` identifying + optimizations to use. If not specified, the default set of optimizations + is applied. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. 
+ """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return OptimizeDataset(dataset, optimizations) + + return _apply_fn + + +class OptimizeDataset(dataset_ops.Dataset): + """A `Dataset` that acts as an identity, and applies optimizations.""" + + def __init__(self, input_dataset, optimizations): + """See `optimize()` for details.""" + super(OptimizeDataset, self).__init__() + self._input_dataset = input_dataset + if optimizations is None: + optimizations = [] + self._optimizations = ops.convert_to_tensor( + optimizations, dtype=dtypes.string, name="optimizations") + + def _as_variant_tensor(self): + return gen_dataset_ops.optimize_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._optimizations, + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index bbb808fbd7730002e48cab47fa8d0fe09e2124d2..83095c7ba1c6465d18490e5197f71bf7f1fe2497 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -17,16 +17,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import csv -from math import ceil import numpy as np from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops from tensorflow.contrib.data.python.ops import interleave_ops from 
tensorflow.contrib.data.python.ops import shuffle_ops +from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -34,9 +37,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import file_io from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import string_ops from tensorflow.python.platform import gfile from tensorflow.python.util import deprecation @@ -68,7 +69,7 @@ def _is_valid_float(str_val, float_dtype): return False -def _infer_type(str_val, na_value, prev_type, float_dtype): +def _infer_type(str_val, na_value, prev_type): """Given a string, infers its tensor type. Infers the type of a value by picking the least 'permissive' type possible, @@ -79,29 +80,34 @@ def _infer_type(str_val, na_value, prev_type, float_dtype): na_value: Additional string to recognize as a NA/NaN CSV value. prev_type: Type previously inferred based on values of this column that we've seen up till now. - float_dtype: Either `tf.float32` or `tf.float64`. Denotes what float type - to parse float strings as. Returns: Inferred dtype. 
""" if str_val in ("", na_value): + # If the field is null, it gives no extra information about its type return prev_type - if _is_valid_int32(str_val) and prev_type in (None, dtypes.int32): - return dtypes.int32 + type_list = [ + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string + ] # list of types to try, ordered from least permissive to most - if _is_valid_int64(str_val) and prev_type in (None, dtypes.int32, - dtypes.int64): - return dtypes.int64 + type_functions = [ + _is_valid_int32, + _is_valid_int64, + lambda str_val: _is_valid_float(str_val, dtypes.float32), + lambda str_val: _is_valid_float(str_val, dtypes.float64), + lambda str_val: True, + ] # Corresponding list of validation functions - if _is_valid_float(str_val, float_dtype) and prev_type != dtypes.string: - return float_dtype + for i in range(len(type_list)): + validation_fn = type_functions[i] + if validation_fn(str_val) and (prev_type is None or + prev_type in type_list[:i + 1]): + return type_list[i] - return dtypes.string - -def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, - comment): +def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header): + """Generator that yields rows of CSV file(s) in order.""" for fn in filenames: with file_io.FileIO(fn, "r") as f: rdr = csv.reader( @@ -112,9 +118,6 @@ def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, next(rdr) # Skip header lines for csv_row in rdr: - if comment is not None and csv_row[0].startswith(comment): - continue # Skip comment lines - if len(csv_row) != num_cols: raise ValueError( "Problem inferring types: CSV row has different number of fields " @@ -123,22 +126,21 @@ def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim, - na_value, header, comment, float_dtype, - num_rows_for_inference, select_columns): + na_value, header, num_rows_for_inference, + 
select_columns): """Infers column types from the first N valid CSV records of files.""" if select_columns is None: select_columns = range(num_cols) inferred_types = [None] * len(select_columns) for i, csv_row in enumerate( - _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, - comment)): + _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header)): if num_rows_for_inference is not None and i >= num_rows_for_inference: break for j, col_index in enumerate(select_columns): inferred_types[j] = _infer_type(csv_row[col_index], na_value, - inferred_types[j], float_dtype) + inferred_types[j]) # Replace None's with a default type inferred_types = [t or dtypes.string for t in inferred_types] @@ -198,6 +200,112 @@ def _get_sorted_col_indices(select_columns, column_names): return result +def _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed): + """Optionally shuffle and repeat dataset, as requested.""" + if num_epochs != 1 and shuffle: + # Use shuffle_and_repeat for perf + return dataset.apply( + shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, + shuffle_seed)) + elif shuffle: + return dataset.shuffle(shuffle_buffer_size, shuffle_seed) + elif num_epochs != 1: + return dataset.repeat(num_epochs) + return dataset + + +def make_tf_record_dataset( + file_pattern, + batch_size, + parser_fn=None, + num_epochs=None, + shuffle=True, + shuffle_buffer_size=None, + shuffle_seed=None, + prefetch_buffer_size=None, + num_parallel_reads=None, + num_parallel_parser_calls=None, + drop_final_batch=False): + """Reads and optionally parses TFRecord files into a dataset. + + Provides common functionality such as batching, optional parsing, shuffling, + and performant defaults. + + Args: + file_pattern: List of files or patterns of TFRecord file paths. + See @{tf.gfile.Glob} for pattern rules. + batch_size: An int representing the number of records to combine + in a single batch. + parser_fn: (Optional.) 
A function accepting string input to parse + and process the record contents. This function must map records + to components of a fixed shape, so they may be batched. By + default, uses the record contents unmodified. + num_epochs: (Optional.) An int specifying the number of times this + dataset is repeated. If None (the default), cycles through the + dataset forever. + shuffle: (Optional.) A bool that indicates whether the input + should be shuffled. Defaults to `True`. + shuffle_buffer_size: (Optional.) Buffer size to use for + shuffling. A large buffer size ensures better shuffling, but + increases memory usage and startup time. + shuffle_seed: (Optional.) Randomization seed to use for shuffling. + prefetch_buffer_size: (Optional.) An int specifying the number of + feature batches to prefetch for performance improvement. + Defaults to auto-tune. Set to 0 to disable prefetching. + num_parallel_reads: (Optional.) Number of threads used to read + records from files. By default or if set to a value >1, the + results will be interleaved. + num_parallel_parser_calls: (Optional.) Number of records to + parse in parallel. Defaults to an automatic selection. + drop_final_batch: (Optional.) Whether the last batch should be + dropped in case its size is smaller than `batch_size`; the + default behavior is not to drop the smaller batch. + + Returns: + A dataset, where each element matches the output of `parser_fn` + except it will have an additional leading `batch_size` dimension, + or a `batch_size`-length 1-D tensor of strings if `parser_fn` is + unspecified. + """ + files = dataset_ops.Dataset.list_files( + file_pattern, shuffle=shuffle, seed=shuffle_seed) + + if num_parallel_reads is None: + # Note: We considered auto-tuning this value, but there is a concern + # that this affects the mixing of records from different files, which + # could affect training convergence/accuracy, so we are defaulting to + # a constant for now.
+ num_parallel_reads = 24 + dataset = core_readers.TFRecordDataset( + files, num_parallel_reads=num_parallel_reads) + + if shuffle_buffer_size is None: + # TODO(josh11b): Auto-tune this value when not specified + shuffle_buffer_size = 10000 + dataset = _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + + if parser_fn is None: + if drop_final_batch: + dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) + else: + dataset = dataset.batch(batch_size) + else: + # TODO(josh11b): if num_parallel_parser_calls is None, use some function + # of num cores instead of map_and_batch's default behavior of one batch. + dataset = dataset.apply(batching.map_and_batch( + parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls, + drop_remainder=drop_final_batch)) + + if prefetch_buffer_size is None: + prefetch_buffer_size = -1 # tf.config.data.AUTOTUNE + if prefetch_buffer_size == 0: + return dataset + else: + return dataset.prefetch(buffer_size=prefetch_buffer_size) + + def make_csv_dataset( file_pattern, batch_size, @@ -209,7 +317,6 @@ def make_csv_dataset( use_quote_delim=True, na_value="", header=True, - comment=None, num_epochs=None, shuffle=True, shuffle_buffer_size=10000, @@ -218,7 +325,6 @@ def make_csv_dataset( num_parallel_reads=1, num_parallel_parser_calls=2, sloppy=False, - default_float_type=dtypes.float32, num_rows_for_inference=100, ): """Reads CSV files into a dataset. @@ -231,8 +337,8 @@ def make_csv_dataset( Args: file_pattern: List of files or patterns of file paths containing CSV records. See @{tf.gfile.Glob} for pattern rules. - batch_size: An int representing the number of consecutive elements of this - dataset to combine in a single batch. + batch_size: An int representing the number of records to combine + in a single batch. column_names: An optional list of strings that corresponds to the CSV columns, in order. One per column of the input record. 
If this is not provided, infers the column names from the first row of the records. @@ -272,15 +378,11 @@ def make_csv_dataset( header: A bool that indicates whether the first rows of provided CSV files correspond to header lines with column names, and should not be included in the data. - comment: An optional character string that marks lines that should not be - parsed as csv records. If this is provided, all lines that start with - this character will not be parsed. num_epochs: An int specifying the number of times this dataset is repeated. If None, cycles through the dataset forever. shuffle: A bool that indicates whether the input should be shuffled. shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size - ensures better shuffling, but would increase memory usage and startup - time. + ensures better shuffling, but increases memory usage and startup time. shuffle_seed: Randomization seed to use for shuffling. prefetch_buffer_size: An int specifying the number of feature batches to prefetch for performance improvement. Recommended value is the number of @@ -294,8 +396,6 @@ def make_csv_dataset( produced is deterministic prior to shuffling (elements are still randomized if `shuffle=True`. Note that if the seed is set, then order of elements after shuffling is deterministic). Defaults to `False`. - default_float_type: Either `tf.float32` or `tf.float64`. If defaults are - not provided, float-like strings are interpreted to be this type. num_rows_for_inference: Number of rows of a file to use for type inference if record_defaults is not provided. If None, reads all the rows of all the files. Defaults to 100. 
@@ -317,8 +417,6 @@ def make_csv_dataset( dataset = dataset.shuffle(len(filenames), shuffle_seed) # Clean arguments; figure out column names and defaults - if comment is not None and len(comment) != 1: - raise ValueError("`comment` arg must be a single-character string or None") if column_names is None: if not header: @@ -341,8 +439,7 @@ def make_csv_dataset( # construction time column_defaults = _infer_column_defaults( filenames, len(column_names), field_delim, use_quote_delim, na_value, - header, comment, default_float_type, num_rows_for_inference, - select_columns) + header, num_rows_for_inference, select_columns) if select_columns is not None and len(column_defaults) != len(select_columns): raise ValueError( @@ -356,71 +453,189 @@ def make_csv_dataset( if label_name is not None and label_name not in column_names: raise ValueError("`label_name` provided must be one of the columns.") - # Define map and filter functions - def filter_fn(line): - return math_ops.not_equal(string_ops.substr(line, 0, 1), comment) - def filename_to_dataset(filename): - ds = core_readers.TextLineDataset(filename) - if header: - ds = ds.skip(1) - if comment is not None: - ds = ds.filter(filter_fn) - return ds + return CsvDataset( + filename, + record_defaults=column_defaults, + field_delim=field_delim, + use_quote_delim=use_quote_delim, + na_value=na_value, + select_cols=select_columns, + header=header) - def decode_csv(line): - """Decodes CSV line into features. + def map_fn(*columns): + """Organizes columns into a features dictionary. Args: - line: String tensor corresponding to one csv record. + *columns: list of `Tensor`s corresponding to one csv record. Returns: - A dictionary of feature names to values for that particular record. If + An OrderedDict of feature names to values for that particular record. If label_name is provided, extracts the label feature to be returned as the second element of the tuple. 
""" - columns = parsing_ops.decode_csv( - line, - column_defaults, - field_delim=field_delim, - use_quote_delim=use_quote_delim, - na_value=na_value, - select_cols=select_columns, - ) - features = dict(zip(column_names, columns)) + features = collections.OrderedDict(zip(column_names, columns)) if label_name is not None: label = features.pop(label_name) return features, label return features - # Read files sequentially or in parallel + # Read files sequentially (if num_parallel_reads=1) or in parallel dataset = dataset.apply( interleave_ops.parallel_interleave( filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy)) - if num_epochs != 1 and shuffle: - # Use shuffle_and_repeat for perf - dataset = dataset.apply( - shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, - shuffle_seed)) - elif shuffle: - dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed) - elif num_epochs != 1: - dataset = dataset.repeat(num_epochs) - - # Use map_and_batch for perf - # TODO(b/76425672): use num_parallel_calls for better performance tuning when - # that is added - dataset = dataset.apply( - batching.map_and_batch( - map_func=decode_csv, - batch_size=batch_size, - num_parallel_batches=int( - ceil(num_parallel_parser_calls / batch_size)))) + dataset = _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + # Apply batch before map for perf, because map has high overhead relative + # to the size of the computation in each map + dataset = dataset.batch(batch_size=batch_size) + dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_parser_calls) dataset = dataset.prefetch(prefetch_buffer_size) + return dataset +_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB + + +class CsvDataset(dataset_ops.Dataset): + """A Dataset comprising lines from one or more CSV files.""" + + def __init__(self, + filenames, + record_defaults, + buffer_size=None, + header=False, + field_delim=",", + use_quote_delim=True, + 
na_value="", + select_cols=None): + """Creates a `CsvDataset` by reading and decoding CSV files. + + The elements of this dataset correspond to records from the file(s). + RFC 4180 format is expected for CSV files + (https://tools.ietf.org/html/rfc4180). + Note that we allow leading and trailing spaces in int or float fields. + + + For example, suppose we have a file 'my_file0.csv' with four CSV columns of + different data types: + ``` + abcdefg,4.28E10,5.55E6,12 + hijklmn,-5.3E14,,2 + ``` + + We can construct a CsvDataset from it as follows: + ```python + dataset = tf.contrib.data.CsvDataset( + "my_file*.csv", + [tf.float32, # Required field, use dtype or empty tensor + tf.constant([0.0], dtype=tf.float32), # Optional field, default to 0.0 + tf.int32, # Required field, use dtype or empty tensor + ], + select_cols=[1,2,3] # Only parse last three columns + ) + ``` + + The expected output of its iterations is: + ```python + nxt = dataset.make_one_shot_iterator().get_next() + with tf.Session() as sess: + while True: + try: + print(sess.run(nxt)) + except tf.errors.OutOfRangeError: + break + + >> (4.28e10, 5.55e6, 12) + >> (-5.3e14, 0.0, 2) + ``` + + Args: + filenames: A `tf.string` tensor containing one or more filenames. + record_defaults: A list of default values for the CSV fields. Each item in + the list is either a valid CSV `DType` (float32, float64, int32, int64, + string), or a `Tensor` object with one of the above types. One per + column of CSV data, with either a scalar `Tensor` default value for the + column if it is optional, or `DType` or empty `Tensor` if required. If + both this and `select_cols` are specified, these must have the same + lengths, and `record_defaults` is assumed to be sorted in order of + increasing column index. + buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes + to buffer while reading files. Defaults to 4MB. + header: (Optional.)
A `tf.bool` scalar indicating whether the CSV file(s) + have header line(s) that should be skipped when parsing. Defaults to + `False`. + field_delim: (Optional.) A `tf.string` scalar containing the delimiter + character that separates fields in a record. Defaults to `","`. + use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats + double quotation marks as regular characters inside of string fields + (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`. + na_value: (Optional.) A `tf.string` scalar indicating a value that will + be treated as NA/NaN. + select_cols: (Optional.) A sorted list of column indices to select from + the input data. If specified, only this subset of columns will be + parsed. Defaults to parsing all columns. + """ + super(CsvDataset, self).__init__() + self._filenames = ops.convert_to_tensor( + filenames, dtype=dtypes.string, name="filenames") + record_defaults = [ + constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x + for x in record_defaults + ] + self._record_defaults = ops.convert_n_to_tensor( + record_defaults, name="record_defaults") + self._buffer_size = convert.optional_param_to_tensor( + "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES) + self._header = ops.convert_to_tensor( + header, dtype=dtypes.bool, name="header") + self._field_delim = ops.convert_to_tensor( + field_delim, dtype=dtypes.string, name="field_delim") + self._use_quote_delim = ops.convert_to_tensor( + use_quote_delim, dtype=dtypes.bool, name="use_quote_delim") + self._na_value = ops.convert_to_tensor( + na_value, dtype=dtypes.string, name="na_value") + self._select_cols = convert.optional_param_to_tensor( + "select_cols", + select_cols, + argument_default=[], + argument_dtype=dtypes.int64, + ) + self._output_shapes = tuple( + tensor_shape.scalar() for _ in range(len(record_defaults))) + self._output_types = tuple(d.dtype for d in self._record_defaults) + self._output_classes = tuple( + ops.Tensor for _ in 
range(len(record_defaults))) + + def _as_variant_tensor(self): + # Constructs graph node for the dataset op. + return contrib_gen_dataset_ops.csv_dataset( + filenames=self._filenames, + record_defaults=self._record_defaults, + buffer_size=self._buffer_size, + header=self._header, + output_shapes=self._output_shapes, + field_delim=self._field_delim, + use_quote_delim=self._use_quote_delim, + na_value=self._na_value, + select_cols=self._select_cols, + ) + + @property + def output_types(self): + return self._output_types + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_classes(self): + return self._output_classes + + def make_batched_features_dataset(file_pattern, batch_size, features, @@ -480,8 +695,8 @@ def make_batched_features_dataset(file_pattern, Args: file_pattern: List of files or patterns of file paths containing `Example` records. See `tf.gfile.Glob` for pattern rules. - batch_size: An int representing the number of consecutive elements of this - dataset to combine in a single batch. + batch_size: An int representing the number of records to combine + in a single batch. features: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. See `tf.parse_example`. reader: A function or class that can be @@ -537,16 +752,10 @@ def make_batched_features_dataset(file_pattern, dataset = dataset.map(lambda _, v: v) # Apply dataset repeat and shuffle transformations. 
- repeat_dataset = (num_epochs != 1) - if repeat_dataset and shuffle: - # Used fused shuffle_and_repeat operation for better performance - dataset = dataset.apply( - shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, - shuffle_seed)) - elif repeat_dataset: - dataset = dataset.repeat(num_epochs) - elif shuffle: - dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed) + dataset = _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + + dataset = dataset.apply(stats_ops.feature_stats("record_stats")) if drop_final_batch: dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) @@ -620,8 +829,8 @@ def read_batch_features(file_pattern, Args: file_pattern: List of files or patterns of file paths containing `Example` records. See `tf.gfile.Glob` for pattern rules. - batch_size: An int representing the number of consecutive elements of this - dataset to combine in a single batch. + batch_size: An int representing the number of records to combine + in a single batch. features: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. See `tf.parse_example`. reader: A function or class that can be diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index e911ad0fa0541f2d8b991d66182dd002c2ecaab0..9909ca8d9d634955dacb325ccc32d2e3650ca534 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -148,6 +148,8 @@ class _ScanDataset(dataset_ops.Dataset): self._output_types = nest.pack_sequence_as( output_value, [t.dtype for t in nest.flatten(output_value)]) + dataset_ops._warn_if_collections("tf.contrib.data.scan()") # pylint: disable=protected-access + # Serialize any sparse tensors. 
new_state = nest.pack_sequence_as(new_state, [ t for t in nest.flatten(sparse.serialize_sparse_tensors(new_state)) diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py index 3cbaab5affd7397213b0fbb6b0682db92b99d591..8c30202ba775527642d9904ca1adc226b5fdb7b9 100644 --- a/tensorflow/contrib/data/python/ops/stats_ops.py +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -176,6 +176,27 @@ def latency_stats(tag): return _apply_fn +def feature_stats(tag): + """Records the features stats from `Example` records of the input dataset. + + To consume the statistics, associate a `StatsAggregator` with the output + dataset. + + Args: + tag: String. All statistics recorded by the returned transformation will be + associated with the given `tag`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return _StatsDataset(dataset, gen_dataset_ops.feature_stats_dataset, tag) + + return _apply_fn + + class _StatsDataset(dataset_ops.Dataset): """A `Dataset` that acts as an identity, and also records statistics.""" diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 8dfcaf6032e1602ed76a8a995553c5d398c4a778..9dfb8552f1b0f058b44f8ed09c2ed681367293d5 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -26,7 +26,6 @@ py_library( "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/contrib/eager/python:datasets", "//tensorflow/python:array_ops", - "//tensorflow/python:checkpointable", "//tensorflow/python:control_flow_ops", "//tensorflow/python:device_util", "//tensorflow/python:distribute", @@ -34,6 +33,7 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/eager:context", + "//tensorflow/python/training/checkpointable:base", "@six_archive//:six", ], ) @@ -77,6 +77,7 @@ py_library( 
"//tensorflow/python:device_util", "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:training", "//tensorflow/python:variable_scope", @@ -148,9 +149,11 @@ py_library( ], deps = [ ":mirrored_strategy", + ":multi_worker_strategy", ":one_device_strategy", ":tpu_strategy", "//tensorflow/contrib/optimizer_v2:training", + "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", "//tensorflow/python:training", "//tensorflow/python:util", @@ -444,10 +447,31 @@ py_library( srcs = ["cross_tower_utils.py"], srcs_version = "PY2AND3", deps = [ + ":values", + "//tensorflow/contrib/all_reduce:all_reduce_py", "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + ], +) + +cuda_py_test( + name = "cross_tower_utils_test", + srcs = ["cross_tower_utils_test.py"], + additional_deps = [ + ":combinations", + ":cross_tower_utils", + "@absl_py//absl/testing:parameterized", + "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + tags = [ + "no_pip", ], ) @@ -469,24 +493,26 @@ py_library( ], ) -py_test( +cuda_py_test( name = "cross_tower_ops_test", srcs = ["cross_tower_ops_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], - deps = [ + additional_deps = [ ":combinations", ":cross_tower_ops", + ":multi_worker_test_base", ":values", + "@absl_py//absl/testing:parameterized", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", - "@absl_py//absl/testing:parameterized", + ], + shard_count = 15, + tags = [ + "multi_and_single_gpu", + "no_pip", ], ) @@ 
-546,3 +572,41 @@ cuda_py_test( "no_pip", ], ) + +cuda_py_test( + name = "keras_test", + srcs = ["keras_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/contrib/distribute/python:mirrored_strategy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:training", + "//tensorflow/python/estimator:keras", + "//tensorflow/python/estimator:run_config", + "//tensorflow/python/keras", + ], + tags = [ + "multi_and_single_gpu", + "noguitar", + "notsan", + ], +) + +cuda_py_test( + name = "metrics_v1_test", + srcs = ["metrics_v1_test.py"], + additional_deps = [ + ":combinations", + "@absl_py//absl/testing:parameterized", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:test", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 45d191127ee7349a59a7e3efa29baeda6445c44a..ba03b14deb9a3897dae29382ce601c0319f84735 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -41,16 +41,21 @@ from __future__ import print_function from collections import OrderedDict import sys +import types +import unittest from absl.testing import parameterized +import six -from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import one_device_strategy -from tensorflow.contrib.distribute.python import tpu_strategy +from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib +from tensorflow.contrib.distribute.python import multi_worker_strategy +from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib +from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib from 
tensorflow.contrib.optimizer_v2 import adam as adam_v2 from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.training import adam +from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import gradient_descent from tensorflow.python.util import tf_inspect @@ -66,29 +71,35 @@ def generate(combinations): combinations: a list of dictionaries created using combine() and times(). Restrictions: - -- there should always be a "mode" argument. Accepted values are "eager" - and "graph". + -- the "mode" argument can be either "eager" or "graph". It's "graph" by + default. -- arguments of the test method must match by name to get the corresponding - value of the combination. Tests must accept all arguments (except "mode", - which is optional). - -- distribution argument is special. It is meant for passing instances of - DistributionStrategy. Each instance is to be passed as `(, - )` tuple, where is the number of required - GPUs. If the required number of GPUs for the DistributionStrategy isn't - available then the test case is going to be skipped. + value of the combination. Tests must accept all arguments except the + "mode", "required_tpu" and "required_gpus". + -- "distribution" argument is special and optional. It is meant for passing + instances of DistributionStrategy. Each instance is to be passed via + `NamedDistribution`. If using "distribution", "required_gpus" and + "required_tpu" should be specified via the NamedDistribution instance, + rather than as separate arguments. + -- "required_tpu" argument is special and optional. If `True`, then the + test will be skipped if TPUs aren't available. + -- "required_gpus" argument is special and optional. If not `None`, then the + test will be skipped if the specified number of GPUs isn't available.
Returns: - a decorator that will cause the test method to be run under the specified - conditions. + a decorator that will cause the test method or the test class to be run + under the specified conditions. Raises: - ValueError - if "mode" argument wasn't either "eager" or "graph. + ValueError - if "mode" argument wasn't either "eager" or "graph" or if other + arguments were not accepted by the test method. """ - def decorator(test_function): + def decorator(test_method_or_class): """The decorator to be returned.""" # Generate good test names that can be used with --test_filter. + named_combinations = [] for combination in combinations: # We use OrderedDicts in `combine()` and `times()` to ensure stable # order of keys in each dictionary. @@ -99,59 +110,96 @@ def generate(combinations): "".join(filter(str.isalnum, str(value)))) for key, value in combination.items() ]) - combination.update({"testcase_name": "_test{}".format(name)}) - - @parameterized.named_parameters(*combinations) - def decorated(self, **kwargs): - """A wrapped test method that sets up `test_function`.""" - assert "mode" in kwargs - mode = kwargs["mode"] - - if "distribution" in kwargs: - distribution = kwargs["distribution"] - kwargs["distribution"] = distribution.strategy - if distribution.required_tpu and not TPU_TEST: - self.skipTest("Test requires a TPU, but it's not available.") - if not distribution.required_tpu and TPU_TEST: - self.skipTest("Test that doesn't require a TPU.") - - if not distribution.required_gpus: - if GPU_TEST: - self.skipTest("Test that doesn't require GPUs.") - elif context.num_gpus() < distribution.required_gpus: - self.skipTest( - "{} GPUs are not available for this test. {} GPUs are available". 
- format(distribution.required_gpus, context.num_gpus())) - - requested_arguments = tf_inspect.getfullargspec(test_function).args - missing_arguments = set(list(kwargs.keys()) + ["self"]).difference( - set(requested_arguments + ["mode"])) - if missing_arguments: - raise ValueError("The test is missing arguments {} .".format( - missing_arguments)) - - kwargs_to_pass = {} - for arg in requested_arguments: - if arg == "self": - kwargs_to_pass[arg] = self - else: - kwargs_to_pass[arg] = kwargs[arg] - - if mode == "eager": - with context.eager_mode(), ops.Graph().as_default(): - test_function(**kwargs_to_pass) - elif mode == "graph": - with context.graph_mode(), ops.Graph().as_default(): - test_function(**kwargs_to_pass) - else: - raise ValueError( - "'mode' has to be either 'eager' or 'graph' and not {}".format( - mode)) + named_combinations.append( + OrderedDict( + list(combination.items()) + [("testcase_name", + "_test{}".format(name))])) + + if isinstance(test_method_or_class, type): + class_object = test_method_or_class + class_object._test_method_ids = test_method_ids = {} + for name, test_method in six.iteritems(class_object.__dict__.copy()): + if (name.startswith(unittest.TestLoader.testMethodPrefix) and + isinstance(test_method, types.FunctionType)): + delattr(class_object, name) + methods = {} + parameterized._update_class_dict_for_param_test_case( + class_object.__name__, methods, test_method_ids, name, + parameterized._ParameterizedTestIter( + _augment_with_special_arguments(test_method), + named_combinations, parameterized._NAMED, name)) + for method_name, method in six.iteritems(methods): + setattr(class_object, method_name, method) + + return class_object + else: + test_method = _augment_with_special_arguments(test_method_or_class) + return parameterized.named_parameters(*named_combinations)(test_method) - return decorated return decorator +def _augment_with_special_arguments(test_method): + def decorated(self, **kwargs): + """A wrapped test method that 
treats some arguments in a special way.""" + mode = kwargs.pop("mode", "graph") + + distribution = kwargs.pop("distribution", None) + required_tpu = kwargs.pop("required_tpu", False) + required_gpus = kwargs.pop("required_gpus", None) + + if distribution: + assert required_gpus is None, ( + "Do not use `required_gpus` and `distribution` together.") + assert required_tpu is False, ( + "Do not use `required_tpu` and `distribution` together.") + kwargs["distribution"] = distribution.strategy + required_gpus = distribution.required_gpus + required_tpu = distribution.required_tpu + + if required_tpu and not TPU_TEST: + self.skipTest("Test requires a TPU, but it's not available.") + if not required_tpu and TPU_TEST: + self.skipTest("Test that doesn't require a TPU.") + + if not required_gpus: + if GPU_TEST: + self.skipTest("Test that doesn't require GPUs.") + elif context.num_gpus() < required_gpus: + self.skipTest( + "{} GPUs are not available for this test. {} GPUs are available". + format(required_gpus, context.num_gpus())) + + # At this point, `kwargs` doesn't have `required_gpus` or `required_tpu` + # that the user might have specified. `kwargs` still has `mode`, which + # the test is allowed to accept or ignore. 
+ requested_arguments = tf_inspect.getfullargspec(test_method).args + missing_arguments = set(list(kwargs.keys()) + ["self"]).difference( + set(requested_arguments + ["mode"])) + if missing_arguments: + raise ValueError("The test is missing arguments {} .".format( + missing_arguments)) + + kwargs_to_pass = {} + for arg in requested_arguments: + if arg == "self": + kwargs_to_pass[arg] = self + else: + kwargs_to_pass[arg] = kwargs[arg] + + if mode == "eager": + with ops.Graph().as_default(), context.eager_mode(): + test_method(**kwargs_to_pass) + elif mode == "graph": + with ops.Graph().as_default(), context.graph_mode(): + test_method(**kwargs_to_pass) + else: + raise ValueError( + "'mode' has to be either 'eager' or 'graph' and not {}".format( + mode)) + return decorated + + def combine(**kwargs): """Generate combinations based on its keyword arguments. @@ -159,7 +207,8 @@ def combine(**kwargs): can be computed using `times()`. Args: - **kwargs: keyword arguments of form `option=[possibilities, ...]`. + **kwargs: keyword arguments of form `option=[possibilities, ...]` + or `option=the_only_possibility`. Returns: a list of dictionaries for each combination. 
Keys in the dictionaries are @@ -178,6 +227,8 @@ def combine(**kwargs): key = first[0] values = first[1] + if not isinstance(values, list): + values = [values] return [ OrderedDict(sorted(list(combined.items()) + [(key, v)], key=sort_by_key)) @@ -239,9 +290,9 @@ class NamedObject(object): class NamedDistribution(object): """Translates DistributionStrategy and its data into a good name.""" - def __init__(self, name, distribution, required_gpus=None, + def __init__(self, name, distribution_fn, required_gpus=None, required_tpu=False): - self._distribution = distribution + self._distribution_fn = distribution_fn self._name = name self._required_gpus = required_gpus self._required_tpu = required_tpu @@ -251,7 +302,7 @@ class NamedDistribution(object): @property def strategy(self): - return self._distribution + return self._distribution_fn() @property def required_gpus(self): @@ -262,25 +313,59 @@ class NamedDistribution(object): return self._required_tpu +# pylint: disable=g-long-lambda +default_strategy = NamedDistribution( + "Default", + lambda: distribute_lib._default_distribution_strategy, # pylint: disable=protected-access + required_gpus=None) one_device_strategy = NamedDistribution( - "OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"), - None) + "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), + required_gpus=None) tpu_strategy_single_iteration = NamedDistribution( "TPUSingleIteration", - tpu_strategy.TPUStrategy(iterations_per_step=1), + lambda: tpu_lib.TPUStrategy(iterations_per_step=1), required_tpu=True) -tpu_strategy = NamedDistribution( - "TPU", tpu_strategy.TPUStrategy(), required_tpu=True) +tpu_strategy = NamedDistribution("TPU", tpu_lib.TPUStrategy, required_tpu=True) +# Note that we disable prefetching for testing since prefetching makes +# the input non-deterministic. 
mirrored_strategy_with_gpu_and_cpu = NamedDistribution( "MirroredCPUAndGPU", - mirrored_strategy.MirroredStrategy(["/gpu:0", "/cpu:0"]), 1) -mirrored_strategy_without_prefetch = NamedDistribution( - "MirroredCPUAndGPUNoPrefetch", - mirrored_strategy.MirroredStrategy( - ["/gpu:0", "/cpu:0"], prefetch_on_device=False), 1) + lambda: mirrored_lib.MirroredStrategy( + ["/gpu:0", "/cpu:0"], prefetch_on_device=False), + required_gpus=1) mirrored_strategy_with_two_gpus = NamedDistribution( "Mirrored2GPUs", - mirrored_strategy.MirroredStrategy(["/gpu:0", "/gpu:1"]), 2) + lambda: mirrored_lib.MirroredStrategy( + ["/gpu:0", "/gpu:1"], prefetch_on_device=False), + required_gpus=2) + +multi_worker_strategy_with_cpu = NamedDistribution( + "MultiWorkerCPU", + lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster={ + "worker": [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + }, + num_gpus_per_worker=0), 0) +multi_worker_strategy_with_one_gpu = NamedDistribution( + "MultiWorker1GPU", + lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster={ + "worker": [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + }, + num_gpus_per_worker=1), 1) +multi_worker_strategy_with_two_gpus = NamedDistribution( + "MultiWorker2GPUs", + lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster={ + "worker": [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + }, + num_gpus_per_worker=2), 2) adam_optimizer_v1_fn = NamedObject( "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1)) diff --git a/tensorflow/contrib/distribute/python/combinations_test.py b/tensorflow/contrib/distribute/python/combinations_test.py index 219b24160f3902fcfa5363cc39a8fc5b30d00308..86aa48cea889c6c2ce169b18bcabb6d08890fbed 100644 --- a/tensorflow/contrib/distribute/python/combinations_test.py +++ b/tensorflow/contrib/distribute/python/combinations_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import 
print_function from collections import OrderedDict +from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations from tensorflow.python.eager import test @@ -41,6 +42,15 @@ class TestingCombinationsTest(test.TestCase): "b": 3 }], combinations.combine(a=[1, 2], b=[2, 3])) + def test_combine_single_parameter(self): + self.assertEqual([{ + "a": 1, + "b": 2 + }, { + "a": 2, + "b": 2 + }], combinations.combine(a=[1, 2], b=2)) + def test_add(self): self.assertEqual( [{ @@ -111,5 +121,28 @@ class TestingCombinationsTest(test.TestCase): _ = combinations.times(c1, c2) +@combinations.generate(combinations.combine(a=[1, 0], b=[2, 3], c=[1])) +class CombineTheTestSuite(parameterized.TestCase): + + def test_add_things(self, a, b, c): + self.assertLessEqual(3, a + b + c) + self.assertLessEqual(a + b + c, 5) + + def test_add_things_one_more(self, a, b, c): + self.assertLessEqual(3, a + b + c) + self.assertLessEqual(a + b + c, 5) + + def not_a_test(self, a=0, b=0, c=0): + del a, b, c + self.fail() + + def _test_but_private(self, a=0, b=0, c=0): + del a, b, c + self.fail() + + # Check that nothing funny happens to a non-callable that starts with "_test". 
+ test_member = 0 + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py index c6a1bf6a9f65828c45617ae18a1b0989f9d46225..f8ae8b9712c392fa948c8598dd123cdea01d9866 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import six from tensorflow.contrib.distribute.python import cross_tower_utils @@ -77,12 +78,12 @@ def _all_devices_match(value_destination_pairs): return True -def _simple_broadcast(tensor, destinations): +def _simple_broadcast(value, destinations): index = {} devices = _get_devices_from(destinations) for d in devices: - with ops.device(d): - index[d] = array_ops.identity(tensor) + index[d] = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + value, d) return value_lib.Mirrored(index) @@ -98,7 +99,9 @@ def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn, continue count += len(v_list) # Sum within each device before aggregating across devices. - v = math_ops.add_n(v_list) + # TODO(yuefengz): Check whether it helps to use accumulation_fn here. 
+ v = cross_tower_utils.aggregate_tensors_or_indexed_slices( + v_list, math_ops.add_n) else: count += 1 all_values.append(v) @@ -107,11 +110,12 @@ def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn, with ops.device(reduce_to_device): with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): - if method_string == "sum": - reduced = accumulation_fn(all_values) - elif method_string == "mean": - reduced = accumulation_fn(all_values) / count - else: + reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices( + all_values, accumulation_fn) + if method_string == "mean": + reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices( + reduced, count) + elif method_string != "sum": raise ValueError("`method_string` must be 'sum' or 'mean'") return reduced @@ -231,7 +235,13 @@ class ReductionToOneDeviceCrossTowerOps(CrossTowerOps): def _group_value_by_device(per_device_values): """Group values into sublists by their devices. - This grouping is needed to call the all-reduce library. + This grouping is needed to call the all-reduce library because it expects a + list of the following form: + [(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ... + (grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ... + (grad0_gpu2, v0_gpu2), (grad1_gpu2, v1_gpu2), (grad2_gpu2, v2_gpu2) ... + ... + ] Args: per_device_values: a list of PerDevice obejcts. @@ -319,7 +329,17 @@ class ConcatAndSplitPacker(object): # TODO(zhengxq): it is also possible to optimize away all the concat # as well. num_splits = self.num_packs - total_grad_size = array_ops.size(concat_grads) + + # The array_ops.size function will sometimes remove static shapes. So if + # all gradient shapes are defined, we use another method to get the + # total size. + # TODO(yuefengz): move this logic to array_ops.size.
+ if all([g.shape.is_fully_defined() for g, _ in tower_grads_and_vars]): + total_grad_size = sum( + [g.shape.num_elements() for g, _ in tower_grads_and_vars]) + else: + total_grad_size = array_ops.size(concat_grads) + split_size = total_grad_size // num_splits split_size_last = total_grad_size - split_size * (num_splits - 1) split_sizes = [split_size] * (num_splits - 1) + [split_size_last] @@ -409,6 +429,31 @@ class AggregateSmallTensorPacker(object): self.packing) +def _pack_tensors(device_grads, + num_packs=0, + agg_small_grads_max_bytes=0, + agg_small_grads_max_group=0): + """Pack tensors if specified.""" + if num_packs > 0: + tensor_packer = ConcatAndSplitPacker(num_packs) + device_grad_packs = tensor_packer.pack(device_grads) + elif agg_small_grads_max_bytes > 0 and agg_small_grads_max_group > 0: + tensor_packer = AggregateSmallTensorPacker(agg_small_grads_max_bytes, + agg_small_grads_max_group) + device_grad_packs = tensor_packer.pack(device_grads) + else: + tensor_packer = None + device_grad_packs = device_grads + return device_grad_packs, tensor_packer + + +def _unpack_tensors(reduced, tensor_packer=None): + """Unpack tensors if they are packed before all-reduce.""" + if tensor_packer: + return tensor_packer.unpack(reduced) + return reduced + + class AllReduceCrossTowerOps(CrossTowerOps): """Reduction using all reduce.""" @@ -437,17 +482,25 @@ class AllReduceCrossTowerOps(CrossTowerOps): agg_small_grads_max_group: see above. tensors. 
""" - self.all_reduce_alg = all_reduce_alg - self.num_packs = num_packs - self.agg_small_grads_max_bytes = agg_small_grads_max_bytes - self.agg_small_grads_max_group = agg_small_grads_max_group + self._all_reduce_alg = all_reduce_alg + self._num_packs = num_packs + self._agg_small_grads_max_bytes = agg_small_grads_max_bytes + self._agg_small_grads_max_group = agg_small_grads_max_group super(AllReduceCrossTowerOps, self).__init__() def _reduce(self, method_string, per_device_value, destinations): + contains_indexed_slices = cross_tower_utils.contains_indexed_slices( + per_device_value) if ((destinations is None or _devices_match(per_device_value, destinations)) - and not context.executing_eagerly()): + and not context.executing_eagerly() + and not contains_indexed_slices): return self._batch_all_reduce(method_string, [per_device_value])[0] else: + if contains_indexed_slices: + logging.log_first_n( + logging.WARN, + "Efficient allreduce is not supported for IndexedSlices.", 10) + devices = _get_devices_from(destinations or per_device_value) reduce_to_device = devices[0] reduced = _simple_reduce(per_device_value, reduce_to_device, @@ -455,14 +508,18 @@ class AllReduceCrossTowerOps(CrossTowerOps): return self.broadcast(reduced, devices) def _batch_reduce(self, method_string, value_destination_pairs): - if (_all_devices_match(value_destination_pairs) and - not context.executing_eagerly()): + all_devices_match = _all_devices_match(value_destination_pairs) + contains_indexed_slices = cross_tower_utils.contains_indexed_slices( + value_destination_pairs) + if (all_devices_match and not context.executing_eagerly() + and not contains_indexed_slices): return self._batch_all_reduce(method_string, [v[0] for v in value_destination_pairs]) else: - if not context.executing_eagerly(): + if not all_devices_match: logging.warning("Efficient batch_reduce is not supported if " "destinations are different.") + return [ self._reduce(method_string, t, destinations=v) for t, v in 
value_destination_pairs @@ -470,37 +527,24 @@ class AllReduceCrossTowerOps(CrossTowerOps): def _batch_all_reduce(self, method_string, per_device_values): """All reduce algorithm in a batch.""" + logging.info( + "batch_all_reduce invoked for batches size = %d with " + "algorithm = %s, num_packs = %d, agg_small_grads_max_bytes = %d and " + "agg_small_grads_max_group = %d", len(per_device_values), + self._all_reduce_alg, self._num_packs, self._agg_small_grads_max_bytes, + self._agg_small_grads_max_group) destinations = per_device_values[0].devices grouped = _group_value_by_device(per_device_values) - if self.num_packs > 0: - logging.info( - "batch_all_reduce invoked for batches size = %d with " - "algorithm = %s and num_packs = %d", len(per_device_values), - self.all_reduce_alg, self.num_packs) - tensor_packer = ConcatAndSplitPacker(self.num_packs) - device_grad_packs = tensor_packer.pack(grouped) - elif (self.agg_small_grads_max_bytes > 0 and - self.agg_small_grads_max_group > 0): - logging.info( - "batch_all_reduce invoked for batches size = %d with " - "algorithm = %s, agg_small_grads_max_bytes = %d and " - "agg_small_grads_max_group = %d", len(per_device_values), - self.all_reduce_alg, self.agg_small_grads_max_bytes, - self.agg_small_grads_max_group) - tensor_packer = AggregateSmallTensorPacker( - self.agg_small_grads_max_bytes, self.agg_small_grads_max_group) - device_grad_packs = tensor_packer.pack(grouped) - else: - logging.info( - "batch_all_reduce invoked for batches size = %d with algorithm = %s", - len(per_device_values), self.all_reduce_alg) - tensor_packer = None - device_grad_packs = grouped + + device_grad_packs, self._tensor_packer = _pack_tensors( + grouped, self._num_packs, self._agg_small_grads_max_bytes, + self._agg_small_grads_max_group) # The actual aggregation of the repacked gradients. Note that they are # sharded among different aggregation trees. So it is important to strike # the balance on num_splits. 
- if self.all_reduce_alg == "nccl": + if self._all_reduce_alg == "nccl": + # TODO(yuefengz): merge this into the all-reduce library. reduced = cross_tower_utils.aggregate_gradients_using_nccl( device_grad_packs) else: @@ -510,13 +554,137 @@ class AllReduceCrossTowerOps(CrossTowerOps): cross_tower_utils.aggregate_gradients_using_hierarchical_copy( destinations, device_grad_packs)) - if tensor_packer: - reduced = tensor_packer.unpack(reduced) - + reduced = _unpack_tensors(reduced, self._tensor_packer) return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices, method_string) +AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple", + "alg shards limit") + + +class MultiWorkerAllReduce(AllReduceCrossTowerOps): + """All-reduce algorithms for distributed TensorFlow.""" + + def __init__(self, + worker_devices, + num_gpus_per_worker, + all_reduce_spec=("pscpu/pscpu", 2, -1), + num_packs=0, + agg_small_grads_max_bytes=0, + agg_small_grads_max_group=10): + """Initialize the all-reduce algorithm. + + Args: + worker_devices: a list of device strings for workers participating in + all-reduce. + num_gpus_per_worker: number of GPU devices per worker. + all_reduce_spec: a tuple or a named tuple or a list of tuples specifying + the all-reduce algorithm. + 1. The first element of a tuple is the name of the all-reduce algorithm. + Valid algorithm names are: "nccl", "nccl/xring", "nccl/rechd", + "nccl/pscpu", "xring", "pscpu", "psgpu", "pscpu/pscpu". Algorithms with + a "/" are hierarchical, so two all-reduces are executed: the first + aggregates tensors within a worker and the second aggregates across + workers. + 2. The second element of a tuple is the number of shards when doing + all-reduce. If its value is M, each tensor after packing is + split into M shards, M parallel all-reduces are performed, and + the results are finally concatenated back into a complete tensor. + 3.
The third element is the maximum size of tensors to which the + algorithm specified by the first element applies. For + example, if all_reduce_spec=[("nccl", 2, 1024), ("pscpu/pscpu", 2, -1)], + tensors no larger than 1024 bytes use a 2-shard + "nccl" all-reduce and all other tensors use a 2-shard + "pscpu/pscpu" algorithm. The third elements should be in increasing + order across tuples and end with -1, which indicates infinity. + num_packs: see AllReduceCrossTowerOps. + agg_small_grads_max_bytes: see AllReduceCrossTowerOps. + agg_small_grads_max_group: see AllReduceCrossTowerOps. + """ + self._worker_devices = worker_devices + self._num_gpus_per_worker = num_gpus_per_worker + super(MultiWorkerAllReduce, self).__init__( + num_packs=num_packs, + agg_small_grads_max_bytes=agg_small_grads_max_bytes, + agg_small_grads_max_group=agg_small_grads_max_group) + + def validate_and_complete_spec(spec): + """Validate and complete the all-reduce spec.""" + # TODO(yuefengz): support namedtuple.
+ if not isinstance(spec, tuple): + raise ValueError( + "A tuple is expected for all-reduce spec: %r" % all_reduce_spec) + if not spec or len(spec) > 3: + raise ValueError( + "Too many elements in the all-reduce spec tuple: %r" % (spec,)) + if len(spec) == 1: + return AllReduceSpecTuple(spec[0], 1, -1) + elif len(spec) == 2: + return AllReduceSpecTuple(spec[0], spec[1], -1) + else: + return AllReduceSpecTuple(*spec) + + self._all_reduce_spec = [] + if isinstance(all_reduce_spec, six.string_types): + self._all_reduce_spec.append(AllReduceSpecTuple(all_reduce_spec, 1, -1)) + elif isinstance(all_reduce_spec, tuple): + self._all_reduce_spec.append(validate_and_complete_spec(all_reduce_spec)) + elif isinstance(all_reduce_spec, list): + self._all_reduce_spec = [ + validate_and_complete_spec(spec) for spec in all_reduce_spec + ] + + def _batch_all_reduce(self, method_string, per_device_values): + """All reduce algorithm in a batch.""" + logging.info( + "distributed batch_all_reduce invoked for batches size = %d with " + "allreduce_spec = %r, num_packs = %d, agg_small_grads_max_bytes = %d " + "and agg_small_grads_max_group = %d", len(per_device_values), + self._all_reduce_spec, self._num_packs, self._agg_small_grads_max_bytes, + self._agg_small_grads_max_group) + + destinations = sorted(per_device_values[0].devices) + device_grads = _group_value_by_device(per_device_values) + + # The all reduce library requires fully defined shapes. + # TODO(yuefengz): when tensor sharding is not needed, static shapes are not + # required either.
+ for device_grad in device_grads: + for grad, _ in device_grad: + if not grad.shape.is_fully_defined(): + raise ValueError("Shape is unknown for node %r" % grad) + + remaining_grads = device_grads + aggregated_grads = [] + for spec_tuple in self._all_reduce_spec: + if spec_tuple.limit < 0: + this_grads = remaining_grads + remaining_grads = [] + else: + (this_grads, remaining_grads) = cross_tower_utils.split_grads_by_size( + spec_tuple.limit, remaining_grads) + if this_grads: + device_grad_packs, self._tensor_packer = _pack_tensors( + this_grads, self._num_packs, self._agg_small_grads_max_bytes, + self._agg_small_grads_max_group) + range_agg_grads = cross_tower_utils.sum_gradients_all_reduce( + self._worker_devices, device_grad_packs, len(self._worker_devices), + spec_tuple.alg, spec_tuple.shards, range(self._num_gpus_per_worker)) + range_agg_grads = _unpack_tensors(range_agg_grads, self._tensor_packer) + + if not aggregated_grads: + aggregated_grads = range_agg_grads + else: + assert len(aggregated_grads) == len(range_agg_grads) + for i in range(len(aggregated_grads)): + aggregated_grads[i] += range_agg_grads[i] + assert not remaining_grads + + return _ungroup_and_make_mirrored(aggregated_grads, destinations, + method_string) + + _dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py index 7c7b0870887465ec2fe40007695d099277db38bf..fed5505d92ef2544215069736c166a67d6141708 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -24,6 +24,7 @@ from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python 
import multi_worker_test_base from tensorflow.contrib.distribute.python import values as value_lib from tensorflow.python.eager import context from tensorflow.python.eager import test @@ -31,6 +32,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.training import device_util def _make_per_device(values, devices): @@ -56,19 +58,46 @@ def _fake_mirrored(value, devices): {d: v for d, v in zip(devices, [value] * len(devices))}) +def _make_indexed_slices(values, indices, dense_shape, device): + with ops.device(device): + tensor = ops.IndexedSlices( + values=constant_op.constant(values), + indices=constant_op.constant(indices), + dense_shape=constant_op.constant(dense_shape)) + return tensor + + +def _make_mirrored_indexed_slices(devices, values, indices, dense_shape): + return value_lib.Mirrored({ + d: _make_indexed_slices(values, indices, dense_shape, d) for d in devices + }) + + _cpu_device = "/device:CPU:0" -class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): +class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase): - def _assert_value_equal(self, left, right): + def _assert_indexed_slices_equal(self, left, right): + self.assertIsInstance(left, ops.IndexedSlices) + self.assertIsInstance(right, ops.IndexedSlices) + self.assertEqual(device_util.resolve(left.device), + device_util.resolve(right.device)) + self.assertAllEqual( + self.evaluate(ops.convert_to_tensor(left)), + self.evaluate(ops.convert_to_tensor(right))) + + def _assert_values_equal(self, left, right): if isinstance(left, list): for l, r in zip(left, right): - self._assert_value_equal(l, r) + self._assert_values_equal(l, r) else: self.assertEqual(type(left), type(right)) self.assertEqual(left.devices, right.devices) - if context.executing_eagerly(): + if isinstance(list(left._index.values())[0], ops.IndexedSlices): + for (d, 
v) in left._index.items(): + self._assert_indexed_slices_equal(v, right._index[d]) + elif context.executing_eagerly(): self.assertEqual([v.numpy() for v in left._index.values()], list(right._index.values())) else: @@ -76,51 +105,7 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): self.assertEqual( sess.run(list(left._index.values())), list(right._index.values())) - # TODO(yuefengz): decouple the num_gpus check from distribution in - # combinations module so that we can pass in devices instead of a distribution - # strategy. - reduction_to_one_combinations = combinations.combine( - cross_tower_ops=[ - combinations.NamedObject( - "DefaultReductionToOneDeviceCrossTowerOps", - cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), - combinations.NamedObject( - "ReductionToCPUDeviceCrossTowerOps", - cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( - reduce_to_device=_cpu_device)), - combinations.NamedObject( - "AccumulateNCrossTowerOp", - cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( - accumulation_fn=math_ops.accumulate_n)), - ], - distribution=[ - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus - ], - mode=["graph", "eager"]) - allreduce_combinations = combinations.combine( - cross_tower_ops=[ - combinations.NamedObject( - "AllReduce", - cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 1, 0, 0)), - combinations.NamedObject( - "HierarchicalCopy", - cross_tower_ops_lib.AllReduceCrossTowerOps( - "hierarchical_copy", 8, 0, 0)), - combinations.NamedObject( - "AllReduceNoGradientRepacking", - cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 0, 0, 0)), - combinations.NamedObject( - "HierarchicalCopyAggregateSmallTensors", - cross_tower_ops_lib.AllReduceCrossTowerOps( - "hierarchical_copy", 0, 100, 10)) - ], - distribution=[combinations.mirrored_strategy_with_two_gpus], - mode=["graph", "eager"]) - - @combinations.generate(reduction_to_one_combinations + 
allreduce_combinations) - def testReductionAndBroadcast(self, cross_tower_ops, distribution): + def _testReductionAndBroadcast(self, cross_tower_ops, distribution): devices = distribution.worker_devices values = [constant_op.constant(float(d)) for d in range(len(devices))] @@ -143,29 +128,29 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): # test reduce() for destinations in all_destinations: - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.reduce("mean", per_device, destinations=destinations), _fake_mirrored(mean, destinations or per_device)) - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.reduce( "mean", per_device_2, destinations=destinations), _fake_mirrored(mean_2, destinations or per_device)) - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.reduce("sum", per_device, destinations=destinations), _fake_mirrored(mean * len(devices), destinations or per_device)) - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.reduce( "sum", per_device_2, destinations=destinations), _fake_mirrored(mean_2 * len(devices), destinations or per_device)) # test batch_reduce() for d1, d2 in itertools.product(all_destinations, all_destinations): - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.batch_reduce( "mean", [(per_device, d1), (per_device_2, d2)]), [_fake_mirrored(mean, d1 or per_device), _fake_mirrored(mean_2, d2 or per_device_2)]) - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.batch_reduce( "sum", [(per_device, d1), (per_device_2, d2)]), [_fake_mirrored(mean * len(devices), d1 or per_device), @@ -176,45 +161,205 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): if destinations is None: continue else: - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.broadcast(constant_op.constant(1.), destinations), _fake_mirrored(1., destinations)) + +class 
SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase): + # TODO(yuefengz): decouple the num_gpus check from distribution in + # combinations module so that we can pass in devices instead of a distribution + # strategy. + reduction_to_one_combinations = combinations.combine( + cross_tower_ops=[ + combinations.NamedObject( + "DefaultReductionToOneDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), + combinations.NamedObject( + "ReductionToCPUDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( + reduce_to_device=_cpu_device)), + combinations.NamedObject( + "AccumulateNCrossTowerOp", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( + accumulation_fn=math_ops.accumulate_n)), + ], + distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus + ], + mode=["graph", "eager"]) + allreduce_combinations = combinations.combine( + cross_tower_ops=[ + combinations.NamedObject( + "AllReduce", + cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 1, 0, 0)), + combinations.NamedObject( + "HierarchicalCopy", + cross_tower_ops_lib.AllReduceCrossTowerOps( + "hierarchical_copy", 8, 0, 0)), + combinations.NamedObject( + "AllReduceNoGradientRepacking", + cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 0, 0, 0)), + combinations.NamedObject( + "HierarchicalCopyAggregateSmallTensors", + cross_tower_ops_lib.AllReduceCrossTowerOps( + "hierarchical_copy", 0, 100, 10)) + ], + distribution=[combinations.mirrored_strategy_with_two_gpus], + mode=["graph", "eager"]) + + @combinations.generate(reduction_to_one_combinations + allreduce_combinations) + def testReductionAndBroadcast(self, cross_tower_ops, distribution): + with distribution.scope(): + self._testReductionAndBroadcast(cross_tower_ops, distribution) + def testChooseAlgorithm(self): device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 
5, 7], [3, 4, 5, 6]] result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertTrue( - isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) - self.assertEqual(result.all_reduce_alg, "hierarchical_copy") - self.assertEqual(result.num_packs, 8) + self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) + self.assertEqual(result._all_reduce_alg, "hierarchical_copy") + self.assertEqual(result._num_packs, 8) # if there are only 4 devices device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]] result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertTrue( - isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) - self.assertEqual(result.all_reduce_alg, "nccl") - self.assertEqual(result.num_packs, 1) + self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) + self.assertEqual(result._all_reduce_alg, "nccl") + self.assertEqual(result._num_packs, 1) # if devices links contain each device itself device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6], [0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7], [2, 4, 5, 6, 7], [3, 4, 5, 6, 7]] result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertTrue( - isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) - self.assertEqual(result.all_reduce_alg, "hierarchical_copy") - self.assertEqual(result.num_packs, 8) + self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) + self.assertEqual(result._all_reduce_alg, "hierarchical_copy") + self.assertEqual(result._num_packs, 8) # if not dgx1-like links device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]] result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertTrue( - isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) - self.assertEqual(result.all_reduce_alg, "nccl") - 
self.assertEqual(result.num_packs, 1) + self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) + self.assertEqual(result._all_reduce_alg, "nccl") + self.assertEqual(result._num_packs, 1) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + required_gpus=1)) + def testSimpleReduceWithIndexedSlices(self): + devices = ["/cpu:0", "/gpu:0"] + t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0]) + t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1]) + per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) + result = cross_tower_ops_lib._simple_reduce(per_device, devices[0], + math_ops.add_n, "sum") + + # Test that the result is semantically equal to both the concatenated + # IndexedSlices with and without duplicate indices. + total_with_dups = _make_indexed_slices( + [[1., 2.], [3., 4.], [5., 6.]], [1, 1, 3], [5, 2], devices[0]) + total_without_dups = _make_indexed_slices( + [[4., 6.], [5., 6.]], [1, 3], [5, 2], devices[0]) + self._assert_indexed_slices_equal(total_with_dups, result) + self._assert_indexed_slices_equal(total_without_dups, result) + + @combinations.generate(combinations.combine( + cross_tower_ops_instance=[ + combinations.NamedObject( + "ReductionToOneDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), + combinations.NamedObject( + "AllReduceCrossTowerOps", + cross_tower_ops_lib.AllReduceCrossTowerOps()) + ], + method_string=["sum", "mean"], + batch_reduce=[True, False], + mode=["graph", "eager"], + required_gpus=1)) + def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, + method_string, batch_reduce): + devices = ["/cpu:0", "/gpu:0"] + dense_shape = [5, 2] + t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0]) + t1 = _make_indexed_slices( + [[3., 4.], [5., 6.]], [1, 3], dense_shape, devices[1]) + per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) + + if batch_reduce: + result = 
cross_tower_ops_instance.batch_reduce(method_string, + [(per_device, devices)]) + else: + result = cross_tower_ops_instance.reduce(method_string, per_device, + devices) + + total_indices_with_dups = [1, 1, 3] + total_indices_without_dups = [1, 3] + + if method_string == "sum": + total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]] + total_values_without_dups = [[4., 6.], [5., 6.]] + else: + assert method_string == "mean" + total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]] + total_values_without_dups = [[2., 3.], [2.5, 3.]] + + total_mirrored_with_dups = _make_mirrored_indexed_slices( + devices, total_values_with_dups, total_indices_with_dups, dense_shape) + total_mirrored_without_dups = _make_mirrored_indexed_slices( + devices, total_values_without_dups, total_indices_without_dups, + dense_shape) + + # Test that the result is semantically equal to both the concatenated + # IndexedSlices, as well as when the duplicate indices are summed up. + if batch_reduce: + total_mirrored_with_dups = [total_mirrored_with_dups] + total_mirrored_without_dups = [total_mirrored_without_dups] + + self._assert_values_equal(total_mirrored_with_dups, result) + self._assert_values_equal(total_mirrored_without_dups, result) + + +class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase, + CrossTowerOpsTestBase): + + worker_devices = [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + multi_worker_allreduce_combinations = combinations.combine( + cross_tower_ops=[ + combinations.NamedObject( + "MultiWorkerAllReduce", + cross_tower_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 0, 0)), + combinations.NamedObject( + "MultiWorkerAllReducePack", + cross_tower_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, ("pscpu/pscpu", 2, -1), 1, 0, 0)), + combinations.NamedObject( + "MultiWorkerAllReduceAggregation", + cross_tower_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 100, 10)), + 
combinations.NamedObject( + "MultiWorkerAllReduceMultipleSpecs", + cross_tower_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, [("pscpu/pscpu", 2, 100), + ("xring", 2, -1)], 0, 0, 0)), + ], + distribution=[ + combinations.multi_worker_strategy_with_cpu, + combinations.multi_worker_strategy_with_one_gpu, + combinations.multi_worker_strategy_with_two_gpus + ], + mode=["graph"]) + + @combinations.generate(multi_worker_allreduce_combinations) + def testReductionAndBroadcast(self, cross_tower_ops, distribution): + with distribution.scope(): + self._testReductionAndBroadcast(cross_tower_ops, distribution) if __name__ == "__main__": diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py index fc04e2195f6d305e0f7c642f24c355286f1a8cfa..2bb088e704c584598b863b1b836166af2a5bb12c 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils.py +++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py @@ -21,9 +21,12 @@ from __future__ import print_function import collections as pycoll from tensorflow.contrib import nccl +from tensorflow.contrib.all_reduce.python import all_reduce +from tensorflow.contrib.distribute.python import values as value_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops @@ -156,6 +159,148 @@ def aggregate_single_gradient_using_copy(grad_and_vars, use_mean, return (grad, v), None +def group_device_names(devices, group_size): + """Group device names into groups of group_size. + + Args: + devices: a list of canonical device strings. + group_size: integer which is equal to or greater than 1. + + Returns: + list of lists of devices, where each inner list is group_size long, + and each device appears at least once in an inner list. 
If + len(devices) % group_size == 0 then each device will appear exactly once. + + Raises: + ValueError: if group_size > len(devices) + """ + num_devices = len(devices) + if group_size > num_devices: + raise ValueError( + 'only %d devices, but group_size=%d' % (num_devices, group_size)) + num_groups = ( + num_devices // group_size + (1 if (num_devices % group_size != 0) else 0)) + groups = [[] for i in range(num_groups)] + for i in range(num_groups * group_size): + groups[i % num_groups].append(devices[i % num_devices]) + return groups + + +def split_grads_by_size(threshold_size, device_grads): + """Break gradients into two sets according to tensor size. + + Args: + threshold_size: int size cutoff for small vs large tensor. + device_grads: List of lists of (gradient, variable) tuples. The outer + list is over devices. The inner list is over individual gradients. + + Returns: + small_grads: Subset of device_grads where shape is <= threshold_size + elements. + large_grads: Subset of device_grads where shape is > threshold_size + elements. + """ + small_grads = [] + large_grads = [] + for dl in device_grads: + small_dl = [] + large_dl = [] + for (g, v) in dl: + tensor_size = g.get_shape().num_elements() + if tensor_size <= threshold_size: + small_dl.append([g, v]) + else: + large_dl.append([g, v]) + if small_dl: + small_grads.append(small_dl) + if large_dl: + large_grads.append(large_dl) + return small_grads, large_grads + + +def sum_grad_and_var_all_reduce(grad_and_vars, + num_workers, + alg, + gpu_indices, + aux_devices=None, + num_shards=1): + """Apply all-reduce algorithm over specified gradient tensors.""" + with ops.name_scope('allreduce'): + # Note that each grad_and_vars looks like the following: + # ((grad0_gpu0, var0_gpu0), ... 
, (grad0_gpuN, var0_gpuN)) + scaled_grads = [g for g, _ in grad_and_vars] + if alg == 'nccl': + summed_grads = nccl.all_sum(scaled_grads) + elif alg == 'xring': + summed_grads = all_reduce.build_ring_all_reduce( + scaled_grads, num_workers, num_shards, gpu_indices, math_ops.add) + elif alg == 'nccl/xring': + summed_grads = all_reduce.build_nccl_then_ring(scaled_grads, num_shards, + math_ops.add) + elif alg == 'nccl/rechd': + summed_grads = all_reduce.build_nccl_then_recursive_hd( + scaled_grads, math_ops.add) + elif alg == 'nccl/pscpu': + summed_grads = all_reduce.build_nccl_then_shuffle( + scaled_grads, aux_devices, math_ops.add, math_ops.add_n) + elif alg == 'pscpu/pscpu': + second_gather_devices = aux_devices[:num_shards] + summed_grads = all_reduce.build_shuffle_then_shuffle( + scaled_grads, aux_devices, second_gather_devices, math_ops.add_n) + elif alg in ['pscpu', 'psgpu']: + summed_grads = all_reduce.build_shuffle_all_reduce( + scaled_grads, aux_devices, math_ops.add_n) + else: + raise ValueError('unsupported all_reduce alg: ', alg) + + result = [] + for (_, v), g in zip(grad_and_vars, summed_grads): + result.append([g, v]) + return result + + +def sum_gradients_all_reduce(dev_prefixes, tower_grads, num_workers, alg, + num_shards, gpu_indices): + """Apply all-reduce algorithm over specified gradient tensors. + + Args: + dev_prefixes: list of prefix strings to use to generate PS device names. + tower_grads: the gradients to reduce. + num_workers: number of worker processes across entire job. + alg: the all-reduce algorithm to apply. + num_shards: alg-specific sharding factor. + gpu_indices: indices of local GPUs in order usable for ring-reduce. 
+ + Returns: + list of reduced tensors + """ + alg_contains_shuffle = any([n in alg for n in ['pscpu', 'psgpu']]) + is_hierarchical = '/' in alg + if 'pscpu' in alg: + aux_devices = [prefix + '/cpu:0' for prefix in dev_prefixes] + elif 'psgpu' in alg: + aux_devices = [ + prefix + '/gpu:%d' % i + for i in range(len(gpu_indices)) + for prefix in dev_prefixes + ] + else: + aux_devices = ['/job:localhost/cpu:0'] + # Auxiliary devices for hierarchical all-reduces. + aux_device_groups = group_device_names( + aux_devices, num_shards if alg_contains_shuffle else 1) + group_index = 0 + reduced_gv_list = [] + for grad_and_vars in zip(*tower_grads): + reduced_gv_list.append( + sum_grad_and_var_all_reduce( + grad_and_vars, num_workers, alg, gpu_indices, aux_devices + if is_hierarchical else aux_device_groups[group_index], num_shards)) + group_index = (group_index + 1) % len(aux_device_groups) + new_tower_grads = [list(x) for x in zip(*reduced_gv_list)] + return new_tower_grads + + def extract_ranges(index_list, range_size_limit=32): """Extract consecutive ranges and singles from index_list. 
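Reviewer note on the hunk above: the round-robin grouping in `group_device_names` is easy to misread, so here is a small standalone sketch of it. The function body is copied from the added lines in `cross_tower_utils.py` (only the comments and the unused loop variable name differ); the device strings in the usage lines are made-up placeholders, not names taken from the patch.

```python
def group_device_names(devices, group_size):
  """Group device names into groups of group_size (logic copied from the patch)."""
  num_devices = len(devices)
  if group_size > num_devices:
    raise ValueError(
        'only %d devices, but group_size=%d' % (num_devices, group_size))
  # Ceiling division: one extra group when the devices do not divide evenly.
  num_groups = (
      num_devices // group_size + (1 if (num_devices % group_size != 0) else 0))
  groups = [[] for _ in range(num_groups)]
  # Deal devices out round-robin, wrapping around so every group is full.
  for i in range(num_groups * group_size):
    groups[i % num_groups].append(devices[i % num_devices])
  return groups


# Hypothetical device names, for illustration only.
even = group_device_names(['/cpu:0', '/cpu:1', '/cpu:2', '/cpu:3'], 2)
uneven = group_device_names(['d0', 'd1', 'd2', 'd3', 'd4'], 2)
print(even)
print(uneven)
```

With five devices and `group_size=2`, the function builds three groups and reuses the first device to fill the last one, which matches the docstring's "at least once" wording. The grouping is also strided instead of contiguous: neighbours in the input list land in different groups.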
@@ -328,7 +473,7 @@ def unpack_small_tensors(tower_grads, packing): for dev_idx, gv_list in enumerate(tower_grads): gv_list = list(gv_list) new_gv_list = gv_list[num_packed:] - for i in xrange(0, num_packed): + for i in range(num_packed): k = '%d:%d' % (dev_idx, i) gpt = packing[k] gv = unpack_grad_tuple(gv_list[i], gpt) @@ -337,3 +482,46 @@ def unpack_small_tensors(tower_grads, packing): new_gv_list.insert(idx, gv[gi]) new_tower_grads.append(new_gv_list) return new_tower_grads + + +def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n): + """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat.""" + if any(isinstance(v, ops.IndexedSlices) for v in values): + return gradients_impl._AggregateIndexedSlicesGradients(values) # pylint: disable=protected-access + else: + return accumulation_fn(values) + + +def divide_by_n_tensors_or_indexed_slices(value, n): + if isinstance(value, ops.IndexedSlices): + value = gradients_impl._HandleNestedIndexedSlices(value) # pylint: disable=protected-access + return ops.IndexedSlices( + value.values / n, value.indices, value.dense_shape) + else: + return value / n + + +def copy_tensor_or_indexed_slices_to_device(value, device): + with ops.device(device): + if isinstance(value, ops.IndexedSlices): + copied_values = array_ops.identity(value.values) + copied_indices = array_ops.identity(value.indices) + copied_shape = array_ops.identity(value.dense_shape) + result = ops.IndexedSlices(copied_values, copied_indices, copied_shape) + else: + result = array_ops.identity(value) + return result + + +def contains_indexed_slices(value): + """Check whether the value is `IndexedSlices` or contains `IndexedSlices`.""" + if isinstance(value, ops.IndexedSlices): + return True + elif isinstance(value, (list, tuple)) and value: + return any(contains_indexed_slices(v) for v in value) + elif isinstance(value, value_lib.DistributedValues): + return contains_indexed_slices(list(value._index.values())) # pylint: 
disable=protected-access + elif isinstance(value, value_lib.MapOutput): + return contains_indexed_slices(value.get()) + else: + return False diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4ef8db681503dcef8c72f641455dbb999cef05cf --- /dev/null +++ b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py @@ -0,0 +1,152 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for cross_tower_utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import cross_tower_utils +from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops +from tensorflow.python.training import device_util + + +class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): + + def _assert_values_equal(self, left, right): + self.assertAllEqual( + self.evaluate(ops.convert_to_tensor(left)), + self.evaluate(ops.convert_to_tensor(right))) + + @test_util.run_in_graph_and_eager_modes() + def testAggregateTensors(self): + t0 = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) + t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]]) + total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]]) + result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1]) + self._assert_values_equal(total, result) + + @test_util.run_in_graph_and_eager_modes() + def testAggregateIndexedSlices(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]]) + result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1]) + self.assertIsInstance(result, ops.IndexedSlices) + self._assert_values_equal(total, result) + + @test_util.run_in_graph_and_eager_modes() + def testDivideTensor(self): + t = constant_op.constant([[1., 2.], [0, 0], 
[3., 4.]]) + n = 2 + expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]) + result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n) + self._assert_values_equal(expected, result) + + @test_util.run_in_graph_and_eager_modes() + def testDivideIndexedSlices(self): + t = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + n = 2 + expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]) + result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n) + self.assertIsInstance(result, ops.IndexedSlices) + self._assert_values_equal(expected, result) + + @test_util.run_in_graph_and_eager_modes() + def testIsIndexedSlices(self): + t = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + self.assertTrue(cross_tower_utils.contains_indexed_slices(t)) + + @test_util.run_in_graph_and_eager_modes() + def testContainsIndexedSlices_List(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + self.assertTrue(cross_tower_utils.contains_indexed_slices([t0, t1])) + + @test_util.run_in_graph_and_eager_modes() + def testContainsIndexedSlices_Tuple(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1))) + + @test_util.run_in_graph_and_eager_modes() + def testContainsIndexedSlices_PerDevice(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + per_device = value_lib.PerDevice({"/gpu:0": t0, "/cpu:0": t1}) + self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device)) + + @test_util.run_in_graph_and_eager_modes() + def 
testContainsIndexedSlices_PerDeviceMapOutput(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + per_device = value_lib.PerDevice({ + "/gpu:0": value_lib.MapOutput([t0]), + "/cpu:0": value_lib.MapOutput([t1])}) + self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device)) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + required_gpus=1)) + def testCopyTensor(self): + with ops.device("/cpu:0"): + t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) + destination = "/gpu:0" + result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + t, destination) + + self._assert_values_equal(t, result) + self.assertEqual(device_util.resolve(destination), + device_util.resolve(result.device)) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + required_gpus=1)) + def testCopyIndexedSlices(self): + with ops.device("/cpu:0"): + t = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + destination = "/gpu:0" + result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + t, destination) + + self.assertIsInstance(result, ops.IndexedSlices) + self._assert_values_equal(t, result) + self.assertEqual(device_util.resolve(destination), + device_util.resolve(result.device)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py new file mode 100644 index 0000000000000000000000000000000000000000..75ecd90dcffa7a786b78238ef453c4c8e4346afa --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -0,0 +1,148 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Keras Sequential and Functional models.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import keras as keras_lib +from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.framework import test_util +from tensorflow.python.keras import testing_utils +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import rmsprop + +_RANDOM_SEED = 1337 +_TRAIN_SIZE = 200 +_INPUT_SIZE = (10,) +_NUM_CLASS = 2 + + +def simple_sequential_model(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(16, activation='relu', input_shape=_INPUT_SIZE)) + model.add(keras.layers.Dropout(0.1)) + model.add(keras.layers.Dense(_NUM_CLASS, activation='softmax')) + return model + + +def simple_functional_model(): + a = keras.layers.Input(shape=_INPUT_SIZE) + b = keras.layers.Dense(16, activation='relu')(a) + b = keras.layers.Dropout(0.1)(b) + b = keras.layers.Dense(_NUM_CLASS, activation='softmax')(b) + model = keras.models.Model(inputs=[a], outputs=[b]) + return model + + +def get_ds_train_input_fn(): + 
np.random.seed(_RANDOM_SEED) + (x_train, y_train), _ = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=_INPUT_SIZE, + num_classes=_NUM_CLASS) + y_train = keras.utils.to_categorical(y_train) + + dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + dataset = dataset.batch(32) + return dataset + + +def get_ds_test_input_fn(): + np.random.seed(_RANDOM_SEED) + _, (x_test, y_test) = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=_INPUT_SIZE, + num_classes=_NUM_CLASS) + y_test = keras.utils.to_categorical(y_test) + + dataset = dataset_ops.Dataset.from_tensor_slices((x_test, y_test)) + dataset = dataset.batch(32) + return dataset + + +class TestKerasDistributionStrategy(test_util.TensorFlowTestCase): + + def setUp(self): + self._base_dir = os.path.join(self.get_temp_dir(), + 'keras_mirrored_strategy_test') + gfile.MakeDirs(self._base_dir) + self._config = run_config_lib.RunConfig( + tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir) + + def tearDown(self): + writer_cache.FileWriterCache.clear() + if os.path.isdir(self._base_dir): + gfile.DeleteRecursively(self._base_dir) + + def test_train_functional_with_distribution_strategy(self): + dist = mirrored_strategy.MirroredStrategy( + devices=['/device:GPU:0', '/device:GPU:1']) + keras_model = simple_functional_model() + keras_model.compile( + loss='categorical_crossentropy', + optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) + config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, + model_dir=self._base_dir, + train_distribute=dist) + with self.test_session(): + est_keras = keras_lib.model_to_estimator( + keras_model=keras_model, config=config) + before_eval_results = est_keras.evaluate( + input_fn=get_ds_test_input_fn, steps=1) + est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16) + after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn, + steps=1) + 
self.assertLess(after_eval_results['loss'], before_eval_results['loss']) + + writer_cache.FileWriterCache.clear() + gfile.DeleteRecursively(self._config.model_dir) + + def test_train_sequential_with_distribution_strategy(self): + dist = mirrored_strategy.MirroredStrategy( + devices=['/device:GPU:0', '/device:GPU:1']) + keras_model = simple_sequential_model() + keras_model.compile( + loss='categorical_crossentropy', + optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) + config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, + model_dir=self._base_dir, + train_distribute=dist) + with self.test_session(): + est_keras = keras_lib.model_to_estimator( + keras_model=keras_model, config=config) + before_eval_results = est_keras.evaluate( + input_fn=get_ds_test_input_fn, steps=1) + est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16) + after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn, + steps=1) + self.assertLess(after_eval_results['loss'], before_eval_results['loss']) + + writer_cache.FileWriterCache.clear() + gfile.DeleteRecursively(self._config.model_dir) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6c6bf143098c1bba64d47efce1bfface7682683d --- /dev/null +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -0,0 +1,438 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for V1 metrics.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.distribute.python import combinations +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import test +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics +from tensorflow.python.ops import variables + + +def _labeled_dataset_fn(): + # First four batches of x: labels, predictions -> (labels == predictions) + # 0: 0, 0 -> True; 1: 1, 1 -> True; 2: 2, 2 -> True; 3: 3, 0 -> False + # 4: 4, 1 -> False; 5: 0, 2 -> False; 6: 1, 0 -> False; 7: 2, 1 -> False + # 8: 3, 2 -> False; 9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False + # 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True + return dataset_ops.Dataset.range(1000).map( + lambda x: {"labels": x % 5, "predictions": x % 3}).batch(4) + + +def _boolean_dataset_fn(): + # First four batches of labels, predictions: {TP, FP, TN, FN} + # with a threshold of 0.5: + # T, T -> TP; F, T -> FP; T, F -> FN + # F, F -> TN; T, T -> TP; F, T -> FP + # T, F -> FN; F, F -> TN; T, T -> TP + # F, T -> FP; T, F -> FN; F, F -> TN + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [True, False, True, False], + "predictions": [True, True, False, 
False]}).repeat().batch(3) + + +def _threshold_dataset_fn(): + # First four batches of labels, predictions: {TP, FP, TN, FN} + # with a threshold of 0.5: + # True, 1.0 -> TP; False, .75 -> FP; True, .25 -> FN + # False, 0.0 -> TN; True, 1.0 -> TP; False, .75 -> FP + # True, .25 -> FN; False, 0.0 -> TN; True, 1.0 -> TP + # False, .75 -> FP; True, .25 -> FN; False, 0.0 -> TN + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [True, False, True, False], + "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(3) + + +def _regression_dataset_fn(): + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [1., .5, 1., 0.], + "predictions": [1., .75, .25, 0.]}).repeat() + + +def all_combinations(): + return combinations.combine( + distribution=[combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus], + mode=["graph"]) + + +# TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k, +# metrics.precision_at_k +class MetricsV1Test(test.TestCase, parameterized.TestCase): + + def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn): + with ops.Graph().as_default(), distribution.scope(): + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() + value, update = distribution.call_for_each_tower( + metric_fn, iterator.get_next()) + update = distribution.group(update) + self.evaluate(variables.local_variables_initializer()) + # TODO(josh11b): Once we switch to using a global batch size for input, + # replace "distribution.num_towers" with "1". + batches_per_update = distribution.num_towers + + # Update variables using the first `num_towers` batches. + self.evaluate(update) + self.assertAllClose(expected_fn(batches_per_update), self.evaluate(value), + 0.001, msg="After first update") + + # Update variables using the second `num_towers` batches. 
+ self.evaluate(update) + self.assertAllClose(expected_fn(2 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After second update") + + if batches_per_update == 1: # Consume 4 input batches + self.evaluate(update) + self.assertAllClose(expected_fn(3 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After third update") + self.evaluate(update) + self.assertAllClose(expected_fn(4 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After fourth update") + + @combinations.generate(all_combinations()) + def testMean(self, distribution): + def _dataset_fn(): + return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(4) + + def _expected_fn(num_batches): + # Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc. + return num_batches * 2 - 0.5 + + self._test_metric(distribution, _dataset_fn, metrics.mean, _expected_fn) + + @combinations.generate(all_combinations()) + def testAccuracy(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.accuracy(labels, predictions) + + def _expected_fn(num_batches): + return [3./4, 3./8, 3./12, 4./16][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanPerClassAccuracy(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.mean_per_class_accuracy( + labels, predictions, num_classes=5) + + def _expected_fn(num_batches): + mean = lambda x: sum(x) / len(x) + return [mean([1., 1., 1., 0., 0.]), + mean([0.5, 0.5, 0.5, 0., 0.]), + mean([1./3, 1./3, 0.5, 0., 0.]), + mean([0.5, 1./3, 1./3, 0., 0.])][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanIOU(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + 
return metrics.mean_iou( + labels, predictions, num_classes=5) + + def _expected_fn(num_batches): + mean = lambda x: sum(x) / len(x) + return [mean([1./2, 1./1, 1./1, 0.]), # no class 4 in first batch + mean([1./4, 1./4, 1./3, 0., 0.]), + mean([1./6, 1./6, 1./5, 0., 0.]), + mean([2./8, 1./7, 1./7, 0., 0.])][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanTensor(self, distribution): + def _dataset_fn(): + dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float) + # Want to produce a fixed, known shape, so drop remainder when batching. + dataset = dataset.apply(batching.batch_and_drop_remainder(4)) + return dataset + + def _expected_fn(num_batches): + # Mean(0, 4, ..., 4 * num_batches - 4) == 2 * num_batches - 2 + # Mean(1, 5, ..., 4 * num_batches - 3) == 2 * num_batches - 1 + # Mean(2, 6, ..., 4 * num_batches - 2) == 2 * num_batches + # Mean(3, 7, ..., 4 * num_batches - 1) == 2 * num_batches + 1 + first = 2. * num_batches - 2. + return [first, first + 1., first + 2., first + 3.] 
+ + self._test_metric( + distribution, _dataset_fn, metrics.mean_tensor, _expected_fn) + + @combinations.generate(all_combinations()) + def testAUCROC(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.auc(labels, predictions, num_thresholds=8, curve="ROC", + summation_method="careful_interpolation") + + def _expected_fn(num_batches): + return [0.5, 7./9, 0.8, 0.75][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testAUCPR(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.auc(labels, predictions, num_thresholds=8, curve="PR", + summation_method="careful_interpolation") + + def _expected_fn(num_batches): + return [0.797267, 0.851238, 0.865411, 0.797267][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalseNegatives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_negatives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 1., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalseNegativesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_negatives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [1.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTrueNegatives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + 
return metrics.true_negatives(labels, predictions) + + def _expected_fn(num_batches): + return [0., 1., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTrueNegativesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_negatives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[0.], [1.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalsePositives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_positives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 2., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalsePositivesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_positives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [2.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTruePositives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_positives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 2., 3., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTruePositivesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + 
predictions = x["predictions"] + return metrics.true_positives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [2.], [3.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testPrecision(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.precision(labels, predictions) + + def _expected_fn(num_batches): + return [0.5, 0.5, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testPrecisionAtThreshold(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.precision_at_thresholds(labels, predictions, [0.5]) + + def _expected_fn(num_batches): + return [[0.5], [0.5], [0.6], [0.5]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRecall(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.recall(labels, predictions) + + def _expected_fn(num_batches): + return [0.5, 2./3, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRecallAtThreshold(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.recall_at_thresholds(labels, predictions, [0.5]) + + def _expected_fn(num_batches): + return [[0.5], [2./3], [0.6], [0.5]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanSquaredError(self, distribution): + def _metric_fn(x): + 
labels = x["labels"] + predictions = x["predictions"] + return metrics.mean_squared_error(labels, predictions) + + def _expected_fn(num_batches): + return [0., 1./32, 0.208333, 0.15625][num_batches - 1] + + self._test_metric( + distribution, _regression_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRootMeanSquaredError(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.root_mean_squared_error(labels, predictions) + + def _expected_fn(num_batches): + return [0., 0.176777, 0.456435, 0.395285][num_batches - 1] + + self._test_metric( + distribution, _regression_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testSensitivityAtSpecificity(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.sensitivity_at_specificity(labels, predictions, 0.8) + + def _expected_fn(num_batches): + return [0.5, 2./3, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testSpecificityAtSensitivity(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.specificity_at_sensitivity(labels, predictions, 0.95) + + def _expected_fn(num_batches): + return [0., 1./3, 0.5, 0.5][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index d2054715f11c47b8fc3bd73288fd13c0fd5e71e8..5c056a7c73def2f1fb4bbe0df4d3f82fdabda3df 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -207,11 +207,11 @@ class 
MinimizeLossStepTest(test.TestCase, parameterized.TestCase): renorm=renorm, update_ops_in_tower_mode=not update_ops_in_cross_tower_mode) - # Disable prefetching since that makes the specific input on each device - # to be non deterministic, and this test relies on specific input being - # on each device. + # Make sure prefetching is disabled since that makes the + # specific input on each device to be non deterministic, and + # this test relies on specific input being on each device. if isinstance(distribution, mirrored_strategy.MirroredStrategy): - distribution._prefetch_on_device = False + self.assertFalse(distribution._prefetch_on_device) iterator = distribution.distribute_dataset( dataset_fn).make_one_shot_iterator() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 8237b23dbbdb10c053de53880d6838113b99be2d..900aa10e93e8881aa236bac8a2873d5c5531c6f6 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import contextlib import threading import six @@ -30,6 +31,7 @@ from tensorflow.python.eager import tape from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.training import coordinator from tensorflow.python.training import device_util @@ -39,6 +41,16 @@ from tensorflow.python.training import distribute as distribute_lib # TODO(josh11b): Replace asserts in this file with if ...: raise ... 
+@contextlib.contextmanager +def _enter_graph(g): + if context.executing_eagerly(): + with g.as_default(), context.eager_mode(): + yield + else: + with g.as_default(): + yield + + def _cpu_device(device): cpu_device = tf_device.DeviceSpec.from_string(device) cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0)) @@ -73,9 +85,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): assert len(set(devices)) == len(devices), ( "No duplicates allowed in `devices` argument.") # TODO(josh11b): Require at least 2 devices? - self._devices = devices - self._canonical_device_set = set( - [device_util.canonicalize(d) for d in devices]) + self._devices = [device_util.resolve(d) for d in devices] + self._canonical_device_set = set(self._devices) self._device_index = values.PerDevice( dict((d, i) for i, d in enumerate(devices))) self._cross_tower_ops = cross_tower_ops @@ -108,13 +119,19 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): if i > 0: # Give replicas meaningful distinct names: var0name = index[devices[0]].name.split(":")[0] - kwargs["name"] = "%s/replica_%d" % (var0name, i) + # We append a / to variable names created on towers with id > 0 to + # ensure that we ignore the name scope and instead use the given + # name as the absolute name of the variable. 
+ kwargs["name"] = "%s/replica_%d/" % (var0name, i) # Initialize replicas with the same value: if context.executing_eagerly(): - initial_value = index[devices[0]].value() + kwargs["initial_value"] = array_ops.identity( + index[devices[0]].value()) else: - initial_value = index[devices[0]].initial_value - kwargs["initial_value"] = array_ops.identity(initial_value) + def initial_value_fn(device=d): + with ops.device(device): + return array_ops.identity(index[devices[0]].initial_value) + kwargs["initial_value"] = initial_value_fn with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): v = next_creator(*args, **kwargs) assert not isinstance(v, values.DistributedVariable) @@ -245,8 +262,15 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): {t.device: t.merge_args for t in threads}) merge_kwargs = values.regroup( {t.device: t.merge_kwargs for t in threads}) - merge_result = threads[0].merge_fn( - self, *merge_args, **merge_kwargs) + # We capture the name_scope of the MTT when we call merge_fn + # to ensure that if we have opened a name scope in the MTT, + # it will be respected when executing the merge function. We only + # capture the name_scope from the first MTT and assume it is + # the same for all other MTTs. 
+ mtt_captured_name_scope = threads[0].captured_name_scope + with ops.name_scope(mtt_captured_name_scope): + merge_result = threads[0].merge_fn( + self, *merge_args, **merge_kwargs) for t in threads: t.merge_result = values.select_device(t.device, merge_result) finally: @@ -320,6 +344,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): **values.select_device_mirrored(d, kwargs)) return values.regroup(updates, values.Mirrored) + def read_var(self, tower_local_var): + """Read the aggregate value of a tower-local variable.""" + if isinstance(tower_local_var, values.TowerLocalVariable): + return math_ops.add_n(self.unwrap(tower_local_var)) + assert isinstance(tower_local_var, values.Mirrored) + return array_ops.identity(tower_local_var.get()) + def _fetch(self, val, destination, fn): """Return a copy of `val` or `fn(val)` on `destination`.""" if isinstance(val, values.TowerLocalVariable): @@ -386,7 +417,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): # pylint: disable=protected-access return list(colocate_with._index.keys()) elif isinstance(colocate_with, six.string_types): - return [colocate_with] + return [device_util.resolve(colocate_with)] + elif isinstance(colocate_with, list): + return [device_util.resolve(d) for d in colocate_with] else: return colocate_with @@ -413,6 +446,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): self.merge_args = None self.merge_kwargs = None self.merge_result = None + self.captured_name_scope = None # We use a thread.Event for the main thread to signal when this # thread should start running (`should_run`), and another for # this thread to transfer control back to the main thread @@ -436,13 +470,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): self._variable_creator_stack = self.graph._variable_creator_stack[:] self._captured_var_scope = variable_scope.get_variable_scope() # Adding a "/" at end lets us re-enter this scope later. 
- self._captured_name_scope = self.graph.get_name_scope() - if self._captured_name_scope: - self._captured_name_scope += "/" + self._name_scope = self.graph.get_name_scope() + if self._name_scope: + self._name_scope += "/" if self.tower_id > 0: - if not self._captured_name_scope: - self._captured_name_scope = "" - self._captured_name_scope += "tower_%d/" % self.tower_id + if not self._name_scope: + self._name_scope = "" + self._name_scope += "tower_%d/" % self.tower_id def run(self): # pylint: disable=protected-access @@ -455,10 +489,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): with self.coord.stop_on_exception(), \ context.context()._mode(self.context_mode), \ context.context().device_policy(self.context_device_policy), \ - self.graph.as_default(), \ + _enter_graph(self.graph), \ MirroredTowerContext(self.distribution, self.tower_id), \ ops.device(self.device), \ - ops.name_scope(self._captured_name_scope), \ + ops.name_scope(self._name_scope), \ variable_scope.variable_scope( self._captured_var_scope, reuse=self.tower_id > 0), \ variable_scope.variable_creator_scope(self.variable_creator_fn): @@ -484,6 +518,10 @@ class MirroredTowerContext(distribute_lib.TowerContext): t.merge_fn = fn t.merge_args = args t.merge_kwargs = kwargs + t.captured_name_scope = t.graph.get_name_scope() + # Adding a "/" at end lets us re-enter this scope later. 
+ if t.captured_name_scope: + t.captured_name_scope += "/" t.has_paused.set() t.should_run.wait() t.should_run.clear() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 3635bd2e34f88ab05a3ddce1728fd53c5b7149b3..bccd278847e3c87080af3cb15665e7a0d802d8fb 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -28,9 +28,12 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.layers import core +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import distribute as distribute_lib @@ -116,7 +119,6 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): self.assertEqual(expected, self.evaluate(unwrapped[0])) -@test_util.with_c_api class MirroredStrategyVariableCreationTest(test.TestCase): config = config_pb2.ConfigProto() @@ -436,6 +438,98 @@ class MirroredStrategyVariableCreationTest(test.TestCase): self.assertEquals("foo/" + name + ":0", v0.name) self.assertEquals("tower_1/foo/" + name + ":0", v1.name) + # variable_scope.variable() respects name scopes when creating + # variables. On the other hand variable_scope.get_variable() ignores name + # scopes when creating variables. We test both methods of creating variables + # to make sure that we have the same variable names in both cases. 
+ def testNameScopeWithVariable(self): + def in_cross_tower(_): + c = variable_scope.variable(1.0, name="c") + return c + + def model_fn(): + b = variable_scope.variable(1.0, name="b") + with ops.name_scope("foo"): + c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + return b, c + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + with ops.name_scope("main"): + a = variable_scope.variable(1.0, name="a") + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + result_b = result[0] + result_c = result[1] + self.assertIsInstance(result_b, values.DistributedValues) + self.assertIsInstance(result_c, values.DistributedValues) + a0, a1 = dist.unwrap(a) + b0, b1 = dist.unwrap(result_b) + c0, c1 = dist.unwrap(result_c) + self.assertEquals("main/a:0", a0.name) + self.assertEquals("main/a/replica_1:0", a1.name) + self.assertEquals("main/b:0", b0.name) + self.assertEquals("main/b/replica_1:0", b1.name) + self.assertEquals("main/foo/c:0", c0.name) + self.assertEquals("main/foo/c/replica_1:0", c1.name) + + def testNameScopeWithGetVariable(self): + def in_cross_tower(_): + c = variable_scope.get_variable("c", [1]) + return c + + def model_fn(): + b = variable_scope.get_variable("b", [1]) + with ops.name_scope("foo"): + c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + return b, c + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + with ops.name_scope("main"): + a = variable_scope.get_variable("a", [1]) + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + result_b = result[0] + result_c = result[1] + self.assertIsInstance(result_b, values.DistributedValues) + self.assertIsInstance(result_c, values.DistributedValues) + a0, a1 = dist.unwrap(a) + b0, b1 = dist.unwrap(result_b) + c0, c1 = dist.unwrap(result_c) + self.assertEquals("a:0", a0.name) + 
self.assertEquals("a/replica_1:0", a1.name) + self.assertEquals("b:0", b0.name) + self.assertEquals("b/replica_1:0", b1.name) + self.assertEquals("c:0", c0.name) + self.assertEquals("c/replica_1:0", c1.name) + + def testDynamicRnnVariables(self): + def model_fn(): + inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]]) + cell_fw = rnn_cell_impl.LSTMCell(300) + cell_bw = rnn_cell_impl.LSTMCell(300) + (outputs, _) = rnn.bidirectional_dynamic_rnn( + cell_fw, + cell_bw, + inputs, + dtype=dtypes.float32) + return outputs + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + # Two variables are created by the RNN layer. + self.assertEquals(2, len(result)) + for v in result: + self.assertIsInstance(v, values.DistributedValues) + _, v1 = dist.unwrap(v) + self.assertStartsWith(v1.name, "tower_1/") + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py index a1ef0ecc77a8e8432dfa4eb6da7c324b371dab70..61cbe6df813bb28bf8baa83d9e28ffafc4f0cbb8 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py @@ -27,7 +27,6 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.training import distribute as distribute_lib -@test_util.with_c_api class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): def _get_distribution_strategy(self): @@ -53,7 +52,6 @@ class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): self._test_call_and_merge_exceptions(self._get_distribution_strategy()) -@test_util.with_c_api class VariableCreatorStackTest(test.TestCase): def testCreatorStacksAreThreadLocal(self): diff --git a/tensorflow/contrib/distribute/python/monitor_test.py 
b/tensorflow/contrib/distribute/python/monitor_test.py index 8277e1e7919e86ef616b31d0986589dcc9c49bbd..4fdb9bf69b4f6ad76b79fd298f5303f24a1bd455 100644 --- a/tensorflow/contrib/distribute/python/monitor_test.py +++ b/tensorflow/contrib/distribute/python/monitor_test.py @@ -25,6 +25,7 @@ from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import monitor as monitor_lib from tensorflow.contrib.distribute.python import one_device_strategy from tensorflow.contrib.distribute.python.single_loss_example import single_loss_example +from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import ops @@ -65,7 +66,7 @@ class MonitorTest(test.TestCase, parameterized.TestCase): step_function, _ = single_loss_example( lambda: gradient_descent.GradientDescentOptimizer(0.2), distribution) - with self.test_session() as sess: + with session.Session() as sess, context.eager_mode(): with self.assertRaisesRegexp(ValueError, "Should not provide"): _ = monitor_lib.Monitor(step_function, sess) diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py index a552b370ebf359464afcaf3211119e73434e0dfb..0f21a427320510635279f80c11711e81715ec37c 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_strategy.py +++ b/tensorflow/contrib/distribute/python/multi_worker_strategy.py @@ -121,7 +121,7 @@ class MultiWorkerMirroredStrategy(MirroredStrategy): worker: [device_util.canonicalize(worker, '/device:CPU:0')] for worker in self._workers } - self._devices = nest.flatten(self._worker_device_map.values()) + self._devices = nest.flatten(self._worker_device_map) super(MultiWorkerMirroredStrategy, self).__init__( devices=self._devices, prefetch_on_device=prefetch_on_device) diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py 
b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py index ee7588163e42ee3c31dd9fd25fc53e3483f0fbee..09c859b32a3150b95fbfcfa5b62b5eca426ddf18 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py +++ b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py @@ -25,11 +25,9 @@ from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util from tensorflow.python.training import server_lib -@test_util.with_c_api class MultiWorkerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, strategy_test_lib.DistributionTestBase): diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 09b6d4a515ab46879520f304cd5ef60469512380..7f4bab9d93814eb70a2a1586fc291a16b2766b90 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -102,6 +102,10 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): with ops.device(self._device), distribute_lib.UpdateContext(self._device): return fn(*args, **kwargs) + def read_var(self, tower_local_var): + """Read the aggregate value of a tower-local variable.""" + return array_ops.identity(tower_local_var) + def _fetch(self, val, destination, fn): """Return a copy of `val` or `fn(val)` on `destination`.""" with ops.device(self._device): diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index 7101ed0756f44b846f10ddc6d429afe005a2f196..7aad8a953cbedd30b48739416e74b3dc164dc4cd 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -24,7 +24,6 @@ from 
tensorflow.python.eager import test from tensorflow.python.framework import test_util -@test_util.with_c_api class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): def _get_distribution_strategy(self): diff --git a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py index 713494d603b855be2863af9f24ab98d4cf048042..a0b452fc2d445d1cf7dbf5e8fe0e29edef516207 100644 --- a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py +++ b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py @@ -44,7 +44,6 @@ class CanonicalizeVariableNameTest(test.TestCase): self.assertEquals("foo_a", self._canonicalize("foo_a")) -@test_util.with_c_api class SharedVariableCreatorTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 759f3c359975bae6c892b65d3ce24c59e9f74116..9572ade8e497fa13a7ca0746399d3e0237ee79fd 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -35,10 +35,10 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops -from tensorflow.python.training import checkpointable from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util import nest @@ -65,9 +65,10 @@ class DistributedValues(object): device = device_util.canonicalize(device) try: return self._index[device] - except KeyError: - raise ValueError("Device %s not found in %s (current device %s)" % - (device, self._index.keys(), device_util.current())) + except KeyError as e: + 
six.raise_from( + ValueError("Device %s not found in %s (current device %s)" % + (device, self._index.keys(), device_util.current())), e) def on_device(self, device): device = device_util.canonicalize(device) diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 9aeef9fa3e86f25ba2544236fd802c7162f4e40e..1c95758d96aba47e9581dde6411763e98b99a968 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -42,7 +42,6 @@ from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import nest -@test_util.with_c_api class DistributedValuesTest(test.TestCase): def testGetEager(self): @@ -81,7 +80,6 @@ class DistributedValuesTest(test.TestCase): v = values.DistributedValues({"/device:cpu:0": 42}) -@test_util.with_c_api class DistributedDelegateTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() @@ -164,7 +162,6 @@ def _make_mirrored(): return v, devices, mirrored -@test_util.with_c_api class RegroupAndSelectDeviceTest(test.TestCase): def _is_per_device(self, result, expected, klass=values.PerDevice): @@ -317,7 +314,6 @@ class RegroupAndSelectDeviceTest(test.TestCase): merged_estimator_spec)) -@test_util.with_c_api class PerDeviceDatasetTest(test.TestCase): config = config_pb2.ConfigProto() @@ -564,7 +560,6 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): multi_worker_iterator.get_next() -@test_util.with_c_api class MirroredVariableTest(test.TestCase): config = config_pb2.ConfigProto() @@ -741,7 +736,6 @@ def _make_tower_local(method): return v, tower_local -@test_util.with_c_api class TowerLocalVariableTest(test.TestCase): config = config_pb2.ConfigProto() diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index a1d56066b417ddd103d17a528d2922ca5853bd55..51f7028566f0119fa58875f0fe47861e59f412c2 100644 --- 
a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -94,7 +94,7 @@ cuda_py_test( cuda_py_test( name = "distribution_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/distribution_test.py"], additional_deps = [ ":distributions_py", @@ -337,7 +337,7 @@ cuda_py_test( cuda_py_test( name = "mvn_tril_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/mvn_tril_test.py"], additional_deps = [ ":distributions_py", @@ -710,6 +710,7 @@ cuda_py_test( "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:client_testlib", ], + shard_count = 4, tags = ["noasan"], # times out, http://b/78588814 ) @@ -939,6 +940,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "fill_triangular_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/fill_triangular_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "gumbel_test", size = "small", @@ -1031,6 +1051,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "matrix_inverse_tril_test", + size = "medium", + srcs = ["python/kernel_tests/bijectors/matrix_inverse_tril_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "real_nvp_test", size = "small", @@ -1098,6 +1137,25 
@@ cuda_py_test( ], ) +cuda_py_test( + name = "scale_tril_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/scale_tril_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "sigmoid_test", size = "small", @@ -1215,6 +1273,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "transform_diagonal_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/transform_diagonal_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "weibull_test", size = "small", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index ddf59891e626a85e6c917ac74b3cfaabf16eb15d..5cec93c4df2e970f203253be6342bb292f296eb0 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -13,8 +13,6 @@ # limitations under the License. # ============================================================================== """Classes representing statistical distributions and ops for working with them. - -See the @{$python/contrib.distributions} guide. 
""" from __future__ import absolute_import from __future__ import division @@ -32,6 +30,7 @@ from tensorflow.contrib.distributions.python.ops.conditional_distribution import from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import * from tensorflow.contrib.distributions.python.ops.deterministic import * from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular +from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular_inverse from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform from tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse @@ -156,6 +155,7 @@ _allowed_symbols = [ 'kl_divergence', 'RegisterKL', 'fill_triangular', + 'fill_triangular_inverse', 'matrix_diag_transform', 'reduce_weighted_logsumexp', 'softplus_inverse', diff --git a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py index 59d549b7b80a3d80d0b8409542eb6583f645bdaa..f2bb2d3325a7cc6ec5803860600149522752a4c0 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py @@ -448,8 +448,7 @@ class _BatchReshapeTest(object): else: with self.test_session(): - with self.assertRaisesOpError(r"`batch_shape` size must match " - r"`distributions.batch_shape` size"): + with self.assertRaisesOpError(r"Shape sizes do not match."): batch_reshape_lib.BatchReshape( distribution=mvn, batch_shape=new_batch_shape_ph, @@ -457,8 +456,13 @@ class _BatchReshapeTest(object): def test_non_positive_shape(self): dims = 2 - new_batch_shape = [-1, -2] # -1*-2=2 so will pass size check. 
- old_batch_shape = [2] + old_batch_shape = [4] + if self.is_static_shape: + # Unknown first dimension does not trigger size check. Note that + # any dimension < 0 is treated statically as unknown. + new_batch_shape = [-1, 0] + else: + new_batch_shape = [-2, -2] # -2 * -2 = 4, same size as the old shape. new_batch_shape_ph = ( constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape @@ -471,7 +475,7 @@ class _BatchReshapeTest(object): mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph) if self.is_static_shape: - with self.assertRaisesRegexp(ValueError, r".*must be positive.*"): + with self.assertRaisesRegexp(ValueError, r".*must be >=-1.*"): batch_reshape_lib.BatchReshape( distribution=mvn, batch_shape=new_batch_shape_ph, @@ -479,7 +483,7 @@ class _BatchReshapeTest(object): else: with self.test_session(): - with self.assertRaisesOpError(r".*must be positive.*"): + with self.assertRaisesOpError(r".*must be >=-1.*"): batch_reshape_lib.BatchReshape( distribution=mvn, batch_shape=new_batch_shape_ph, diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py index 8b279ebcd908b6f375b35594ac5f3db9228a1e31..f8a52615b0f3f5ad0c7e01e0f76c7d7a6b455ef7 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py @@ -59,7 +59,7 @@ class ConditionalBijectorTest(test.TestCase): for name in ["inverse_log_det_jacobian", "forward_log_det_jacobian"]: method = getattr(b, name) with self.assertRaisesRegexp(ValueError, name + ".*b1.*b2"): - method(1., event_ndims=0., arg1="b1", arg2="b2") + method(1., event_ndims=0, arg1="b1", arg2="b2") if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py 
b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py new file mode 100644 index 0000000000000000000000000000000000000000..caeaf2a0c6e4fff28c0edd82cb09ca0bcee85fc3 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py @@ -0,0 +1,98 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for FillTriangular bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class FillTriangularBijectorTest(test.TestCase): + """Tests the correctness of the FillTriangular bijector.""" + + @test_util.run_in_graph_and_eager_modes() + def testBijector(self): + x = np.float32(np.array([1., 2., 3.])) + y = np.float32(np.array([[3., 0.], + [2., 1.]])) + + b = bijectors.FillTriangular() + + y_ = self.evaluate(b.forward(x)) + self.assertAllClose(y, y_) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + fldj = 
self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1)) + self.assertAllClose(fldj, 0.) + + ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) + self.assertAllClose(ildj, 0.) + + @test_util.run_in_graph_and_eager_modes() + def testShape(self): + x_shape = tensor_shape.TensorShape([5, 4, 6]) + y_shape = tensor_shape.TensorShape([5, 4, 3, 3]) + + b = bijectors.FillTriangular(validate_args=True) + + x = array_ops.ones(shape=x_shape, dtype=dtypes.float32) + y_ = b.forward(x) + self.assertAllEqual(y_.shape.as_list(), y_shape.as_list()) + x_ = b.inverse(y_) + self.assertAllEqual(x_.shape.as_list(), x_shape.as_list()) + + y_shape_ = b.forward_event_shape(x_shape) + self.assertAllEqual(y_shape_.as_list(), y_shape.as_list()) + x_shape_ = b.inverse_event_shape(y_shape) + self.assertAllEqual(x_shape_.as_list(), x_shape.as_list()) + + y_shape_tensor = self.evaluate( + b.forward_event_shape_tensor(x_shape.as_list())) + self.assertAllEqual(y_shape_tensor, y_shape.as_list()) + x_shape_tensor = self.evaluate( + b.inverse_event_shape_tensor(y_shape.as_list())) + self.assertAllEqual(x_shape_tensor, x_shape.as_list()) + + @test_util.run_in_graph_and_eager_modes() + def testShapeError(self): + + b = bijectors.FillTriangular(validate_args=True) + + x_shape_bad = tensor_shape.TensorShape([5, 4, 7]) + with self.assertRaisesRegexp(ValueError, "is not a triangular number"): + b.forward_event_shape(x_shape_bad) + with self.assertRaisesOpError("is not a triangular number"): + self.evaluate(b.forward_event_shape_tensor(x_shape_bad.as_list())) + + y_shape_bad = tensor_shape.TensorShape([5, 4, 3, 2]) + with self.assertRaisesRegexp(ValueError, "Matrix must be square"): + b.inverse_event_shape(y_shape_bad) + with self.assertRaisesOpError("Matrix must be square"): + self.evaluate(b.inverse_event_shape_tensor(y_shape_bad.as_list())) + + +if __name__ == "__main__": + test.main() diff --git 
a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py new file mode 100644 index 0000000000000000000000000000000000000000..18397035571561731698b06d90e20dc74e3cf83c --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py @@ -0,0 +1,190 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MatrixInverseTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class MatrixInverseTriLBijectorTest(test.TestCase): + """Tests the correctness of the Y = inv(tril) transformation.""" + + @test_util.run_in_graph_and_eager_modes() + def testComputesCorrectValues(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + self.assertEqual("matrix_inverse_tril", inv.name) + x_ = np.array([[0.7, 0., 0.], + [0.1, -1., 0.], + [0.3, 0.25, 0.5]], dtype=np.float32) + x_inv_ = np.linalg.inv(x_) + expected_fldj_ = -6. 
* np.sum(np.log(np.abs(np.diag(x_)))) + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertNear(expected_fldj_, fldj_, err=1e-3) + self.assertNear(-expected_fldj_, ildj_, err=1e-3) + + @test_util.run_in_graph_and_eager_modes() + def testOneByOneMatrix(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[5.]], dtype=np.float32) + x_inv_ = np.array([[0.2]], dtype=np.float32) + expected_fldj_ = np.log(0.04) + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertNear(expected_fldj_, fldj_, err=1e-3) + self.assertNear(-expected_fldj_, ildj_, err=1e-3) + + @test_util.run_in_graph_and_eager_modes() + def testZeroByZeroMatrix(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.eye(0, dtype=np.float32) + x_inv_ = np.eye(0, dtype=np.float32) + expected_fldj_ = 0. 
+ + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertNear(expected_fldj_, fldj_, err=1e-3) + self.assertNear(-expected_fldj_, ildj_, err=1e-3) + + @test_util.run_in_graph_and_eager_modes() + def testBatch(self): + # Test batch computation with input shape (2, 1, 2, 2), i.e. batch shape + # (2, 1). + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[[[1., 0.], + [2., 3.]]], + [[[4., 0.], + [5., -6.]]]], dtype=np.float32) + x_inv_ = np.linalg.inv(x_) + expected_fldj_ = -4. * np.sum( + np.log(np.abs(np.diagonal(x_, axis1=-2, axis2=-1))), axis=-1) + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertAllClose(expected_fldj_, fldj_, atol=0., rtol=1e-3) + self.assertAllClose(-expected_fldj_, ildj_, atol=0., rtol=1e-3) + + @test_util.run_in_graph_and_eager_modes() + def testErrorOnInputRankTooLow(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([0.1], dtype=np.float32) + rank_error_msg = "must have rank at least 2" + with self.test_session(): + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.forward(x_).eval() + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.inverse(x_).eval() + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + with 
self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + # TODO(b/80481923): Figure out why these assertions fail, and fix them. + ## def testErrorOnInputNonSquare(self): + ## inv = bijectors.MatrixInverseTriL(validate_args=True) + ## x_ = np.array([[1., 2., 3.], + ## [4., 5., 6.]], dtype=np.float32) + ## square_error_msg = "must be a square matrix" + ## with self.test_session(): + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.forward(x_).eval() + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.inverse(x_).eval() + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + @test_util.run_in_graph_and_eager_modes() + def testErrorOnInputNotLowerTriangular(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[1., 2.], + [3., 4.]], dtype=np.float32) + triangular_error_msg = "must be lower triangular" + with self.test_session(): + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.forward(x_).eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.inverse(x_).eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + @test_util.run_in_graph_and_eager_modes() + def testErrorOnInputSingular(self): + inv = 
bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[1., 0.], + [0., 0.]], dtype=np.float32) + nonsingular_error_msg = "must have all diagonal entries nonzero" + with self.test_session(): + with self.assertRaisesOpError(nonsingular_error_msg): + inv.forward(x_).eval() + with self.assertRaisesOpError(nonsingular_error_msg): + inv.inverse(x_).eval() + with self.assertRaisesOpError(nonsingular_error_msg): + inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesOpError(nonsingular_error_msg): + inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py index 46f2c63f9b0f78b25bb1948e6ea55ab20c5cfa6e..d44e49b4874a5b91f7633cd9c97dbb1a7da70f27 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py @@ -22,15 +22,12 @@ import numpy as np from tensorflow.contrib.distributions.python.ops.bijectors.reshape import Reshape from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite from tensorflow.python.platform import test -@test_util.with_c_api class _ReshapeBijectorTest(object): """Base class for testing the reshape transformation. 
@@ -265,7 +262,6 @@ class _ReshapeBijectorTest(object): raise NotImplementedError("Subclass failed to implement `build_shapes`.") -@test_util.with_c_api class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): def build_shapes(self, shape_in, shape_out): @@ -305,21 +301,13 @@ class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): bijector, x, y, event_ndims=2, rtol=1e-6, atol=0) def testInvalidDimensionsOpError(self): - if ops._USE_C_API: - error_message = "Invalid value in tensor used for shape: -2" - else: - error_message = "elements must be either positive integers or `-1`." - self._testInvalidDimensionsOpError(error_message) + self._testInvalidDimensionsOpError( + "Invalid value in tensor used for shape: -2") def testInputOutputMismatchOpError(self): - if ops._USE_C_API: - error_message = "Cannot reshape a tensor with" - else: - error_message = "Input to reshape is a tensor with" - self._testInputOutputMismatchOpError(error_message) + self._testInputOutputMismatchOpError("Cannot reshape a tensor with") -@test_util.with_c_api class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest): def build_shapes(self, shape_in, shape_out): @@ -341,7 +329,6 @@ class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest): self._testInputOutputMismatchOpError("Input to reshape is a tensor with") -@test_util.with_c_api class ReshapeBijectorTestDynamicNdims(test.TestCase, _ReshapeBijectorTest): def build_shapes(self, shape_in, shape_out): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py new file mode 100644 index 0000000000000000000000000000000000000000..566a7b3dff9b5d97a1cb143e0b32fc15984c3a02 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py @@ -0,0 +1,69 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ScaleTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class ScaleTriLBijectorTest(test.TestCase): + """Tests the correctness of the ScaleTriL bijector.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + def testComputesCorrectValues(self): + shift = 1.61803398875 + x = np.float32(np.array([-1, .5, 2])) + y = np.float32(np.array([[np.exp(2) + shift, 0.], + [.5, np.exp(-1) + shift]])) + + b = bijectors.ScaleTriL(diag_bijector=bijectors.Exp(), + diag_shift=shift) + + y_ = self.evaluate(b.forward(x)) + self.assertAllClose(y, y_) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + @test_util.run_in_graph_and_eager_modes() + def testInvertible(self): + + # Generate random inputs from an unconstrained space, with + # event size 6 to specify 3x3 triangular matrices. 
+ batch_shape = [2, 1] + x = np.float32(np.random.randn(*(batch_shape + [6]))) + b = bijectors.ScaleTriL(diag_bijector=bijectors.Softplus(), + diag_shift=3.14159) + y = self.evaluate(b.forward(x)) + self.assertAllEqual(y.shape, batch_shape + [3, 3]) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1)) + ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) + self.assertAllClose(fldj, -ildj) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py index 45760a29ee42835da69ef63803ccec7ce82a5a8f..795f1993ba5c31bf5a26333f31f1bc73125bff07 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py @@ -151,16 +151,24 @@ class SinhArcsinhBijectorTest(test.TestCase): self.assertAllClose(y, bijector.forward(x).eval(), rtol=1e-4, atol=0.) self.assertAllClose(x, bijector.inverse(y).eval(), rtol=1e-4, atol=0.) - # Do the numpy calculation in float128 to avoid inf/nan. - y_float128 = np.float128(y) - self.assertAllClose( - np.log(np.cosh( - np.arcsinh(y_float128) / tailweight - skewness) / np.sqrt( - y_float128**2 + 1)) - - np.log(tailweight), - bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), - rtol=1e-4, - atol=0.) + # On IBM PPC systems, longdouble (np.float128) is same as double except that it can have more precision. + # Type double being of 8 bytes, can't hold square of max of float64 (which is also 8 bytes) and + # below test fails due to overflow error giving inf. So this check avoids that error by skipping square + # calculation and corresponding assert. 
+ + if np.amax(y) <= np.sqrt(np.finfo(np.float128).max) and \ + np.fabs(np.amin(y)) <= np.sqrt(np.fabs(np.finfo(np.float128).min)): + + # Do the numpy calculation in float128 to avoid inf/nan. + y_float128 = np.float128(y) + self.assertAllClose( + np.log(np.cosh( + np.arcsinh(y_float128) / tailweight - skewness) / np.sqrt( + y_float128**2 + 1)) - + np.log(tailweight), + bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), + rtol=1e-4, + atol=0.) self.assertAllClose( -bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), bijector.forward_log_det_jacobian(x, event_ndims=0).eval(), diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6428a68702274fae384ae3de6d03f7ca126e2346 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py @@ -0,0 +1,66 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for TransformDiagonal bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class TransformDiagonalBijectorTest(test.TestCase): + """Tests correctness of the TransformDiagonal bijector.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + @test_util.run_in_graph_and_eager_modes() + def testBijector(self): + x = np.float32(np.random.randn(3, 4, 4)) + + y = x.copy() + for i in range(x.shape[0]): + np.fill_diagonal(y[i, :, :], np.exp(np.diag(x[i, :, :]))) + + exp = bijectors.Exp() + b = bijectors.TransformDiagonal(diag_bijector=exp) + + y_ = self.evaluate(b.forward(x)) + self.assertAllClose(y, y_) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=2)) + ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) + self.assertAllEqual( + fldj, + self.evaluate(exp.forward_log_det_jacobian( + np.array([np.diag(x_mat) for x_mat in x]), + event_ndims=1))) + self.assertAllEqual( + ildj, + self.evaluate(exp.inverse_log_det_jacobian( + np.array([np.diag(y_mat) for y_mat in y]), + event_ndims=1))) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index 31d24aa9ea09007b8db40e4869371b1f62639ac7..bbbec2103aefd3f38a9b734bcd3f2e15fc8bb683 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -29,7 +29,9 @@ from 
tensorflow.contrib.distributions.python.ops import mvn_diag from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import categorical from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.linalg import linear_operator_diag @@ -540,5 +542,51 @@ class PadDynamicTest(_PadTest, test.TestCase): return False +class TestMoveDimension(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def test_move_dimension_static_shape(self): + + x = random_ops.random_normal(shape=[200, 30, 4, 1, 6]) + + x_perm = distribution_util.move_dimension(x, 1, 1) + self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 4, 1, 6]) + + x_perm = distribution_util.move_dimension(x, 0, 3) + self.assertAllEqual(x_perm.shape.as_list(), [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 0, -2) + self.assertAllEqual(x_perm.shape.as_list(), [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 4, 2) + self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 6, 4, 1]) + + @test_util.run_in_graph_and_eager_modes() + def test_move_dimension_dynamic_shape(self): + + x_ = random_ops.random_normal(shape=[200, 30, 4, 1, 6]) + x = array_ops.placeholder_with_default(input=x_, shape=None) + + x_perm = distribution_util.move_dimension(x, 1, 1) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [200, 30, 4, 1, 6]) + + x_perm = distribution_util.move_dimension(x, 0, 3) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 0, -2) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 4, 2) + 
self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [200, 30, 6, 4, 1]) + + x_perm = distribution_util.move_dimension(x, -1, 2) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [200, 30, 6, 4, 1]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py index 968057331787059240110b90545f70c0ab128aa8..b91a610acf1a9094d612504d63030b3bffb873ac 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py @@ -65,6 +65,16 @@ class SeedStreamTest(test.TestCase): self.assertAllUnique( outputs + [strm2() for _ in range(50)] + [strm3() for _ in range(50)]) + def testInitFromOtherSeedStream(self): + strm1 = seed_stream.SeedStream(seed=4, salt="salt") + strm2 = seed_stream.SeedStream(strm1, salt="salt") + strm3 = seed_stream.SeedStream(strm1, salt="another salt") + out1 = [strm1() for _ in range(50)] + out2 = [strm2() for _ in range(50)] + out3 = [strm3() for _ in range(50)] + self.assertAllEqual(out1, out2) + self.assertAllUnique(out1 + out3) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py index ce6cf702d522792f1ad26066a3d9be42003a0e3c..9c4dfed83631e9f0815fb674d650cac2e570b923 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py @@ -98,23 +98,21 @@ class StatisticalTestingTest(test.TestCase): num_samples = 5000 # 5000 samples is chosen to be enough to find discrepancies of # size 0.1 or more with assurance 1e-6, as confirmed here: - with self.test_session() as sess: - d = 
st.min_discrepancy_of_true_means_detectable_by_dkwm( - num_samples, 0., 1., false_fail_rate=1e-6, false_pass_rate=1e-6) - d = sess.run(d) - self.assertLess(d, 0.1) + d = st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, 0., 1., false_fail_rate=1e-6, false_pass_rate=1e-6) + d = self.evaluate(d) + self.assertLess(d, 0.1) # Test that the confidence interval computed for the mean includes # 0.5 and excludes 0.4 and 0.6. - with self.test_session() as sess: - samples = rng.uniform(size=num_samples).astype(np.float32) - (low, high) = st.true_mean_confidence_interval_by_dkwm( - samples, 0., 1., error_rate=1e-6) - low, high = sess.run([low, high]) - self.assertGreater(low, 0.4) - self.assertLess(low, 0.5) - self.assertGreater(high, 0.5) - self.assertLess(high, 0.6) + samples = rng.uniform(size=num_samples).astype(np.float32) + (low, high) = st.true_mean_confidence_interval_by_dkwm( + samples, 0., 1., error_rate=1e-6) + low, high = self.evaluate([low, high]) + self.assertGreater(low, 0.4) + self.assertLess(low, 0.5) + self.assertGreater(high, 0.5) + self.assertLess(high, 0.6) def test_dkwm_mean_one_sample_assertion(self): rng = np.random.RandomState(seed=0) @@ -123,21 +121,45 @@ class StatisticalTestingTest(test.TestCase): # Test that the test assertion agrees that the mean of the standard # uniform distribution is 0.5. samples = rng.uniform(size=num_samples).astype(np.float32) - with self.test_session() as sess: - sess.run(st.assert_true_mean_equal_by_dkwm( - samples, 0., 1., 0.5, false_fail_rate=1e-6)) - - # Test that the test assertion confirms that the mean of the - # standard uniform distribution is not 0.4. - with self.assertRaisesOpError("Mean confidence interval too high"): - sess.run(st.assert_true_mean_equal_by_dkwm( - samples, 0., 1., 0.4, false_fail_rate=1e-6)) - - # Test that the test assertion confirms that the mean of the - # standard uniform distribution is not 0.6. 
- with self.assertRaisesOpError("Mean confidence interval too low"): - sess.run(st.assert_true_mean_equal_by_dkwm( - samples, 0., 1., 0.6, false_fail_rate=1e-6)) + self.evaluate(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.5, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not 0.4. + with self.assertRaisesOpError("true mean greater than expected"): + self.evaluate(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.4, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not 0.6. + with self.assertRaisesOpError("true mean smaller than expected"): + self.evaluate(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.6, false_fail_rate=1e-6)) + + def test_dkwm_mean_in_interval_one_sample_assertion(self): + rng = np.random.RandomState(seed=0) + num_samples = 5000 + + # Test that the test assertion agrees that the mean of the standard + # uniform distribution is between 0.4 and 0.6. + samples = rng.uniform(size=num_samples).astype(np.float32) + self.evaluate(st.assert_true_mean_in_interval_by_dkwm( + samples, 0., 1., + expected_low=0.4, expected_high=0.6, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not between 0.2 and 0.4. + with self.assertRaisesOpError("true mean greater than expected"): + self.evaluate(st.assert_true_mean_in_interval_by_dkwm( + samples, 0., 1., + expected_low=0.2, expected_high=0.4, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not between 0.6 and 0.8. 
+ with self.assertRaisesOpError("true mean smaller than expected"): + self.evaluate(st.assert_true_mean_in_interval_by_dkwm( + samples, 0., 1., + expected_low=0.6, expected_high=0.8, false_fail_rate=1e-6)) def test_dkwm_mean_two_sample_assertion(self): rng = np.random.RandomState(seed=0) @@ -145,20 +167,18 @@ class StatisticalTestingTest(test.TestCase): # 4000 samples is chosen to be enough to find discrepancies of # size 0.2 or more with assurance 1e-6, as confirmed here: - with self.test_session() as sess: - d = st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( - num_samples, 0., 1., num_samples, 0., 1., - false_fail_rate=1e-6, false_pass_rate=1e-6) - d = sess.run(d) - self.assertLess(d, 0.2) + d = st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( + num_samples, 0., 1., num_samples, 0., 1., + false_fail_rate=1e-6, false_pass_rate=1e-6) + d = self.evaluate(d) + self.assertLess(d, 0.2) # Test that the test assertion agrees that the standard # uniform distribution has the same mean as itself. samples1 = rng.uniform(size=num_samples).astype(np.float32) samples2 = rng.uniform(size=num_samples).astype(np.float32) - with self.test_session() as sess: - sess.run(st.assert_true_mean_equal_by_dkwm_two_sample( - samples1, 0., 1., samples2, 0., 1., false_fail_rate=1e-6)) + self.evaluate(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., samples2, 0., 1., false_fail_rate=1e-6)) def test_dkwm_mean_two_sample_assertion_beta_2_1_false(self): rng = np.random.RandomState(seed=0) @@ -168,15 +188,14 @@ class StatisticalTestingTest(test.TestCase): # As established above, 4000 samples is enough to find discrepancies # of size 0.2 or more with assurance 1e-6. - with self.test_session() as sess: - # Test that the test assertion confirms that the mean of the - # standard uniform distribution is different from the mean of beta(2, 1). 
- beta_high_samples = rng.beta(2, 1, size=num_samples).astype(np.float32) - with self.assertRaisesOpError("samples1 has a smaller mean"): - sess.run(st.assert_true_mean_equal_by_dkwm_two_sample( - samples1, 0., 1., - beta_high_samples, 0., 1., - false_fail_rate=1e-6)) + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is different from the mean of beta(2, 1). + beta_high_samples = rng.beta(2, 1, size=num_samples).astype(np.float32) + with self.assertRaisesOpError("true mean smaller than expected"): + self.evaluate(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., + beta_high_samples, 0., 1., + false_fail_rate=1e-6)) def test_dkwm_mean_two_sample_assertion_beta_1_2_false(self): rng = np.random.RandomState(seed=0) @@ -186,15 +205,14 @@ class StatisticalTestingTest(test.TestCase): # As established above, 4000 samples is enough to find discrepancies # of size 0.2 or more with assurance 1e-6. - with self.test_session() as sess: - # Test that the test assertion confirms that the mean of the - # standard uniform distribution is different from the mean of beta(1, 2). - beta_low_samples = rng.beta(1, 2, size=num_samples).astype(np.float32) - with self.assertRaisesOpError("samples2 has a smaller mean"): - sess.run(st.assert_true_mean_equal_by_dkwm_two_sample( - samples1, 0., 1., - beta_low_samples, 0., 1., - false_fail_rate=1e-6)) + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is different from the mean of beta(1, 2). 
+ beta_low_samples = rng.beta(1, 2, size=num_samples).astype(np.float32) + with self.assertRaisesOpError("true mean greater than expected"): + self.evaluate(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., + beta_low_samples, 0., 1., + false_fail_rate=1e-6)) def test_dkwm_argument_validity_checking(self): rng = np.random.RandomState(seed=0) @@ -203,18 +221,17 @@ class StatisticalTestingTest(test.TestCase): # Test that the test library complains if the given samples fall # outside the purported bounds. - with self.test_session() as sess: - with self.assertRaisesOpError("maximum value exceeds expectations"): - sess.run(st.true_mean_confidence_interval_by_dkwm( - samples, [[0., 1.]], [[0.5, 1.5]], error_rate=0.5)) - with self.assertRaisesOpError("minimum value falls below expectations"): - sess.run(st.true_mean_confidence_interval_by_dkwm( - samples, [[0.5, 1.5]], [[1., 2.]], error_rate=0.5)) - - # But doesn't complain if they don't. - op = st.true_mean_confidence_interval_by_dkwm( - samples, [[0., 1.]], [[1., 2.]], error_rate=0.5) - _ = sess.run(op) + with self.assertRaisesOpError("maximum value exceeds expectations"): + self.evaluate(st.true_mean_confidence_interval_by_dkwm( + samples, [[0., 1.]], [[0.5, 1.5]], error_rate=0.5)) + with self.assertRaisesOpError("minimum value falls below expectations"): + self.evaluate(st.true_mean_confidence_interval_by_dkwm( + samples, [[0.5, 1.5]], [[1., 2.]], error_rate=0.5)) + + # But doesn't complain if they don't. 
+ op = st.true_mean_confidence_interval_by_dkwm( + samples, [[0., 1.]], [[1., 2.]], error_rate=0.5) + _ = self.evaluate(op) def test_do_maximum_mean(self): n = 117 @@ -223,10 +240,9 @@ class StatisticalTestingTest(test.TestCase): samples = rng.uniform(size=n).astype(np.float32) # Compute the answer in TF using the code under test - with self.test_session() as sess: - envelope_t = ops.convert_to_tensor(envelope) - max_mean = st._do_maximum_mean(samples, envelope_t, 1) - max_mean = sess.run(max_mean) + envelope_t = ops.convert_to_tensor(envelope) + max_mean = st._do_maximum_mean(samples, envelope_t, 1) + max_mean = self.evaluate(max_mean) # Compute the correct answer for this case in numpy. In this # example, `n` and `envelope` are such that `samples[2]` is the diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..03e26b198ea02ad1bef8bcd2f6076078ecd7df0b --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD @@ -0,0 +1,48 @@ +# Description: +# Internal testing utilities, e.g., computing the correct answer to +# put in a unit test. 
+ +licenses(["notice"]) # Apache 2.0 + +py_library( + name = "correlation_matrix_volumes_py", + srcs = [ + "correlation_matrix_volumes_lib.py", + ], + deps = [ + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//third_party/py/numpy", + ], +) + +py_binary( + name = "correlation_matrix_volumes", + srcs = [ + "correlation_matrix_volumes.py", + ], + deps = [ + ":correlation_matrix_volumes_py", + ], +) + +py_test( + name = "correlation_matrix_volumes_test", + size = "medium", + srcs = ["correlation_matrix_volumes_test.py"], + tags = ["no_pip"], + deps = [ + ":correlation_matrix_volumes_py", + # For statistical testing + "//tensorflow/contrib/distributions:distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + ], +) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py new file mode 100644 index 0000000000000000000000000000000000000000..2eab51cd3053ea55f2e03619fd002fbf48251fb1 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py @@ -0,0 +1,98 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Executable to estimate the volume of various sets of correlation matrices. + +See correlation_matrix_volumes_lib.py for purpose and methodology. + +Invocation example: +``` +python correlation_matrix_volumes.py --num_samples 1e7 +``` + +This will compute 10,000,000-sample confidence intervals for the +volumes of several sets of correlation matrices. Which sets, and the +desired statistical significance, are hard-coded in this source file. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pprint + +from absl import app +from absl import flags + +from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr + +FLAGS = flags.FLAGS + +# Float to support giving the number of samples in scientific notation. +# The production run used for the LKJ test used 1e7 samples. +flags.DEFINE_float('num_samples', 1e4, 'Number of samples to use.') + + +def ctv_debatched(det_bounds, dim, num_samples, error_rate=1e-6, seed=42): + # This wrapper undoes the batching in compute_true_volumes, because + # apparently several 5x5x9x1e7 Tensors of float32 can strain RAM. 
+ bounds = {} + for db in det_bounds: + bounds[db] = corr.compute_true_volumes( + [db], dim, num_samples, error_rate=error_rate, seed=seed)[db] + return bounds + + +# The particular bounds in all three of these functions were chosen by +# a somewhat arbitrary walk through an empirical tradeoff, for the +# purpose of testing the LKJ distribution. Setting the determinant +# bound lower +# - Covers more of the testee's sample space, and +# - Increases the probability that the rejection sampler will hit, thus +# - Decreases the relative error (at a fixed sample count) in the +# rejection-based volume estimate; +# but also +# - Increases the variance of the estimator used in the LKJ test. +# This latter variance is also affected by the dimension and the +# tested concentration parameter, and can be compensated for with more +# compute (expensive) or a looser discrepancy limit (unsatisfying). +# The values here are the projection of the points in that test design +# space that ended up getting chosen. 
+def compute_3x3_volumes(num_samples): + det_bounds = [0.01, 0.25, 0.3, 0.35, 0.4, 0.45] + return ctv_debatched( + det_bounds, 3, num_samples, error_rate=5e-7, seed=46) + + +def compute_4x4_volumes(num_samples): + det_bounds = [0.01, 0.25, 0.3, 0.35, 0.4, 0.45] + return ctv_debatched( + det_bounds, 4, num_samples, error_rate=5e-7, seed=47) + + +def compute_5x5_volumes(num_samples): + det_bounds = [0.01, 0.2, 0.25, 0.3, 0.35, 0.4] + return ctv_debatched( + det_bounds, 5, num_samples, error_rate=5e-7, seed=48) + + +def main(_): + full_bounds = {} + full_bounds[3] = compute_3x3_volumes(int(FLAGS.num_samples)) + full_bounds[4] = compute_4x4_volumes(int(FLAGS.num_samples)) + full_bounds[5] = compute_5x5_volumes(int(FLAGS.num_samples)) + pprint.pprint(full_bounds) + +if __name__ == '__main__': + app.run(main) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..455e71f00c96e799c4aaae25050c77a9ae36df06 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py @@ -0,0 +1,323 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Estimating the volume of the correlation matrices with bounded determinant. + +Why? Because lkj_test.py tests the sampler for the LKJ distribution +by estimating the same volume another way. + +How? Rejection sampling. Or, more precisely, importance sampling, +proposing from the uniform distribution on symmetric matrices with +diagonal 1s and entries in [-1, 1]. Such a matrix is a correlation +matrix if and only if it is also positive semi-definite. + +The samples can then be converted into a confidence interval on the +volume in question by the [Clopper-Pearson +method](https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval), +also implemented here. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib +import sys + +import numpy as np + +from tensorflow.python.client import session +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import uniform +from tensorflow.python.ops.distributions import util +from tensorflow.python.platform import tf_logging + +__all__ = [ + "correlation_matrix_volume_rejection_samples", + "compute_true_volumes", +] + + +def try_import(name): # pylint: disable=invalid-name + module = None + try: + module = importlib.import_module(name) + except ImportError as e: + tf_logging.warning("Could not import %s: %s" % (name, str(e))) + return module + +optimize = try_import("scipy.optimize") +stats = try_import("scipy.stats") + + +def _psd_mask(x): + """Computes whether each square matrix in the input is positive semi-definite. + + Args: + x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`. + + Returns: + mask: A floating-point `Tensor` of shape `[B1, ... Bn]`. 
Each + scalar is 1 if the corresponding matrix was PSD, otherwise 0. + """ + # Allegedly + # https://scicomp.stackexchange.com/questions/12979/testing-if-a-matrix-is-positive-semi-definite + # it is more efficient to test for positive semi-definiteness by + # trying to compute the Cholesky decomposition -- the matrix is PSD + # if you succeed and not PSD if you fail. However, TensorFlow's + # Cholesky raises an exception if _any_ of the input matrices are + # not PSD, from which I don't know how to extract _which ones_, so I + # proceed by explicitly computing all the eigenvalues and checking + # whether they are all positive or not. + # + # Also, as was discussed in the answer, it is somewhat dangerous to + # treat SPD-ness as binary in floating-point arithmetic. Cholesky + # factorization can complete and 'look' like everything is fine + # (e.g., O(1) entries and a diagonal of all ones) but the matrix can + # have an exponential condition number. + eigenvalues, _ = linalg_ops.self_adjoint_eig(x) + return math_ops.cast( + math_ops.reduce_min(eigenvalues, axis=-1) >= 0, dtype=x.dtype) + + +def _det_large_enough_mask(x, det_bounds): + """Returns whether the input matches the given determinant limit. + + Args: + x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`. + det_bounds: A floating-point `Tensor` that must broadcast to shape + `[B1, ..., Bn]`, giving the desired lower bound on the + determinants in `x`. + + Returns: + mask: A floating-point `Tensor` of shape [B1, ..., Bn]. Each + scalar is 1 if the corresponding matrix had determinant above + the corresponding bound, otherwise 0. + """ + # For the curious: I wonder whether it is possible and desirable to + # use a Cholesky decomposition-based algorithm for this, since the + # only matrices whose determinant this code cares about will be PSD. + # Didn't figure out how to code that in TensorFlow. 
+ # + # Expert opinion is that it would be about twice as fast since + # Cholesky is roughly half the cost of Gaussian Elimination with + # Partial Pivoting. But this is less of an impact than the switch in + # _psd_mask. + return math_ops.cast( + linalg_ops.matrix_determinant(x) > det_bounds, dtype=x.dtype) + + +def _uniform_correlation_like_matrix(num_rows, batch_shape, dtype, seed): + """Returns a uniformly random `Tensor` of "correlation-like" matrices. + + A "correlation-like" matrix is a symmetric square matrix with all entries + between -1 and 1 (inclusive) and 1s on the main diagonal. Of these, + the ones that are positive semi-definite are exactly the correlation + matrices. + + Args: + num_rows: Python `int` dimension of the correlation-like matrices. + batch_shape: `Tensor` or Python `tuple` of `int` shape of the + batch to return. + dtype: `dtype` of the `Tensor` to return. + seed: Random seed. + + Returns: + matrices: A `Tensor` of shape `batch_shape + [num_rows, num_rows]` + and dtype `dtype`. Each entry is in [-1, 1], and each matrix + along the bottom two dimensions is symmetric and has 1s on the + main diagonal. + """ + num_entries = num_rows * (num_rows + 1) / 2 + ones = array_ops.ones(shape=[num_entries], dtype=dtype) + # It seems wasteful to generate random values for the diagonal since + # I am going to throw them away, but `fill_triangular` fills the + # diagonal, so I probably need them. + # It's not impossible that it would be more efficient to just fill + # the whole matrix with random values instead of messing with + # `fill_triangular`. Then would need to filter almost half out with + # `matrix_band_part`. 
+ unifs = uniform.Uniform(-ones, ones).sample(batch_shape, seed=seed) + tril = util.fill_triangular(unifs) + symmetric = tril + array_ops.matrix_transpose(tril) + diagonal_ones = array_ops.ones( + shape=util.pad(batch_shape, axis=0, back=True, value=num_rows), + dtype=dtype) + return array_ops.matrix_set_diag(symmetric, diagonal_ones) + + +def correlation_matrix_volume_rejection_samples( + det_bounds, dim, sample_shape, dtype, seed): + """Returns rejection samples from trying to get good correlation matrices. + + The proposal being rejected from is the uniform distribution on + "correlation-like" matrices. We say a matrix is "correlation-like" + if it is a symmetric square matrix with all entries between -1 and 1 + (inclusive) and 1s on the main diagonal. Of these, the ones that + are positive semi-definite are exactly the correlation matrices. + + The rejection algorithm, then, is to sample a `Tensor` of + `sample_shape` correlation-like matrices of dimensions `dim` by + `dim`, and check each one for (i) being a correlation matrix (i.e., + PSD), and (ii) having determinant at least the corresponding entry + of `det_bounds`. + + Args: + det_bounds: A `Tensor` of lower bounds on the determinants of + acceptable matrices. The shape must broadcast with `sample_shape`. + dim: A Python `int` dimension of correlation matrices to sample. + sample_shape: Python `tuple` of `int` shape of the samples to + compute, excluding the two matrix dimensions. + dtype: The `dtype` in which to do the computation. + seed: Random seed. + + Returns: + weights: A `Tensor` of shape `sample_shape`. Each entry is 0 if the + corresponding matrix was not a correlation matrix, or had too + small of a determinant. Otherwise, the entry is the + multiplicative inverse of the density of proposing that matrix + uniformly, i.e., the volume of the set of `dim` by `dim` + correlation-like matrices. + volume: The volume of the set of `dim` by `dim` correlation-like + matrices. 
+ """ + with ops.name_scope("rejection_sampler"): + rej_proposals = _uniform_correlation_like_matrix( + dim, sample_shape, dtype, seed=seed) + rej_proposal_volume = 2. ** (dim * (dim - 1) / 2.) + # The density of proposing any given point is 1 / rej_proposal_volume; + # The weight of that point should be scaled by + # 1 / density = rej_proposal_volume. + rej_weights = rej_proposal_volume * _psd_mask( + rej_proposals) * _det_large_enough_mask(rej_proposals, det_bounds) + return rej_weights, rej_proposal_volume + + +def _clopper_pearson_confidence_interval(samples, error_rate): + """Computes a confidence interval for the mean of the given 1-D distribution. + + Assumes (and checks) that the given distribution is Bernoulli, i.e., + takes only two values. This licenses using the CDF of the binomial + distribution for the confidence, which is tighter (for extreme + probabilities) than the DKWM inequality. The method is known as the + [Clopper-Pearson method] + (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval). + + Assumes: + + - The given samples were drawn iid from the distribution of interest. + + - The given distribution is a Bernoulli, i.e., supported only on + low and high. + + Guarantees: + + - The probability (over the randomness of drawing the given sample) + that the true mean is outside the returned interval is no more + than the given error_rate. + + Args: + samples: `np.ndarray` of samples drawn iid from the distribution + of interest. + error_rate: Python `float` admissible rate of mistakes. + + Returns: + low: Lower bound of confidence interval. + high: Upper bound of confidence interval. + + Raises: + ValueError: If `samples` has rank other than 1 (batch semantics + are not implemented), or if `samples` contains values other than + `low` or `high` (as that makes the distribution not Bernoulli). + """ + # TODO(b/78025336) Migrate this confidence interval function + # to statistical_testing.py. 
In order to do that + # - Get the binomial CDF from the Binomial distribution + # - Implement scalar root finding in TF. Batch bisection search + # shouldn't be too hard, and is definitely good enough for this + # problem. Batching the Brent algorithm (from scipy) that is used + # here may be more involved, but may also not be necessary---it's + # only used here because scipy made it convenient. In particular, + # robustness is more important than speed here, which may make + # bisection search actively better. + # - The rest is just a matter of rewriting in the appropriate style. + if optimize is None or stats is None: + raise ValueError( + "Scipy is required for computing Clopper-Pearson confidence intervals") + if len(samples.shape) != 1: + raise ValueError("Batch semantics not implemented") + n = len(samples) + low = np.amin(samples) + high = np.amax(samples) + successes = np.count_nonzero(samples - low) + failures = np.count_nonzero(samples - high) + if successes + failures != n: + uniques = np.unique(samples) + msg = ("Purportedly Bernoulli distribution had distinct samples" + " {}, {}, and {}".format(uniques[0], uniques[1], uniques[2])) + raise ValueError(msg) + def p_small_enough(p): + prob = stats.binom.logcdf(successes, n, p) + return prob - np.log(error_rate / 2.) + def p_big_enough(p): + prob = stats.binom.logsf(successes, n, p) + return prob - np.log(error_rate / 2.) + high_p = optimize.brentq( + p_small_enough, float(successes) / n, 1., rtol=1e-9) + low_p = optimize.brentq( + p_big_enough, 0., float(successes) / n, rtol=1e-9) + low_interval = low + (high - low) * low_p + high_interval = low + (high - low) * high_p + return (low_interval, high_interval) + + +def compute_true_volumes( + det_bounds, dim, num_samples, error_rate=1e-6, seed=42): + """Returns confidence intervals for the desired correlation matrix volumes. 
+ + The confidence intervals are computed by the [Clopper-Pearson method] + (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval). + + Args: + det_bounds: A rank-1 numpy array of lower bounds on the + determinants of acceptable matrices. Entries must be unique. + dim: A Python `int` dimension of correlation matrices to sample. + num_samples: The number of samples to draw. + error_rate: The statistical significance of the returned + confidence intervals. The significance is broadcast: Each + returned interval separately may be incorrect with probability + (under the sample of correlation-like matrices drawn internally) + at most `error_rate`. + seed: Random seed. + + Returns: + bounds: A Python `dict` mapping each determinant bound to the low, high + tuple giving the confidence interval. + """ + bounds = {} + with session.Session() as sess: + rej_weights, _ = correlation_matrix_volume_rejection_samples( + det_bounds, dim, [num_samples, len(det_bounds)], np.float32, seed=seed) + rej_weights = sess.run(rej_weights) + for rw, det in zip(np.rollaxis(rej_weights, 1), det_bounds): + template = ("Estimating volume of {}x{} correlation " + "matrices with determinant >= {}.") + print(template.format(dim, dim, det)) + sys.stdout.flush() + bounds[det] = _clopper_pearson_confidence_interval( + rw, error_rate=error_rate) + return bounds diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8f99300e63871119800a42f122c8321e5986541a --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py @@ -0,0 +1,150 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for correlation_matrix_volumes_lib.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr +from tensorflow.contrib.distributions.python.ops import statistical_testing as st +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.platform import test + + +# NxN correlation matrices are determined by the N*(N-1)/2 +# lower-triangular entries. In addition to being between -1 and 1, +# they must also obey the constraint that the determinant of the +# resulting symmetric matrix is non-negative. In 2x2, we can even +# analytically compute the volume when the determinant is bounded to > +# epsilon, as that boils down to the one lower-triangular entry being +# less than sqrt(1 - epsilon) in absolute value. +def two_by_two_volume(det_bound): + return 2 * np.sqrt(1.0 - det_bound) + + +# The post +# https://psychometroscar.com/the-volume-of-a-3-x-3-correlation-matrix/ +# derives (with elementary calculus) that the volume (with respect to +# Lebesgue^3 measure) of the set of 3x3 correlation matrices is +# pi^2/2.
The same result is also obtained by [1]. +def three_by_three_volume(): + return np.pi**2 / 2. + + +# The volume of the unconstrained set of correlation matrices is also +# the normalization constant of the LKJ distribution from [2]. As +# part of defining the distribution, that reference derives a general +# formula for this volume for all dimensions. A TensorFlow +# computation thereof gave the below result for 4x4: +def four_by_four_volume(): + # This constant was computed as math_ops.exp(lkj.log_norm_const(4, [1.0])) + return 11.6973076 + +# [1] Rousseeuw, P. J., & Molenberghs, G. (1994). "The shape of +# correlation matrices." The American Statistician, 48(4), 276-279. + +# [2] Daniel Lewandowski, Dorota Kurowicka, and Harry Joe, "Generating +# random correlation matrices based on vines and extended onion +# method," Journal of Multivariate Analysis 100 (2009), pp 1989-2001. + + +class CorrelationMatrixVolumesTest(test.TestCase): + + def testRejection2D(self): + num_samples = int(1e5) # Chosen for a small min detectable discrepancy + det_bounds = np.array( + [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32) + exact_volumes = two_by_two_volume(det_bounds) + (rej_weights, + rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples( + det_bounds, 2, [num_samples, 9], dtype=np.float32, seed=43) + # shape of rej_weights: [num_samples, 9, 2, 2] + chk1 = st.assert_true_mean_equal_by_dkwm( + rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes, + false_fail_rate=1e-6) + chk2 = check_ops.assert_less( + st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, low=0., high=rej_proposal_volume, + # Correct the false fail rate due to different broadcasting + false_fail_rate=1.1e-7, false_pass_rate=1e-6), + 0.036) + with ops.control_dependencies([chk1, chk2]): + rej_weights = array_ops.identity(rej_weights) + self.evaluate(rej_weights) + + def testRejection3D(self): + num_samples = int(1e5) # Chosen for a small min
detectable discrepancy + det_bounds = np.array([0.0], dtype=np.float32) + exact_volumes = np.array([three_by_three_volume()], dtype=np.float32) + (rej_weights, + rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples( + det_bounds, 3, [num_samples, 1], dtype=np.float32, seed=44) + # shape of rej_weights: [num_samples, 1, 3, 3] + chk1 = st.assert_true_mean_equal_by_dkwm( + rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes, + false_fail_rate=1e-6) + chk2 = check_ops.assert_less( + st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, low=0., high=rej_proposal_volume, + false_fail_rate=1e-6, false_pass_rate=1e-6), + # Going for about a 3% relative error + 0.15) + with ops.control_dependencies([chk1, chk2]): + rej_weights = array_ops.identity(rej_weights) + self.evaluate(rej_weights) + + def testRejection4D(self): + num_samples = int(1e5) # Chosen for a small min detectable discrepancy + det_bounds = np.array([0.0], dtype=np.float32) + exact_volumes = [four_by_four_volume()] + (rej_weights, + rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples( + det_bounds, 4, [num_samples, 1], dtype=np.float32, seed=45) + # shape of rej_weights: [num_samples, 1, 4, 4] + chk1 = st.assert_true_mean_equal_by_dkwm( + rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes, + false_fail_rate=1e-6) + chk2 = check_ops.assert_less( + st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, low=0., high=rej_proposal_volume, + false_fail_rate=1e-6, false_pass_rate=1e-6), + # Going for about a 10% relative error + 1.1) + with ops.control_dependencies([chk1, chk2]): + rej_weights = array_ops.identity(rej_weights) + self.evaluate(rej_weights) + + def testVolumeEstimation2D(self): + # Test that the confidence intervals produced by + # corr.compute_true_volumes are sound, in the sense of containing + # the exact volume.
+ num_samples = int(1e5) # Chosen by symmetry with testRejection2D + det_bounds = np.array( + [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32) + volume_bounds = corr.compute_true_volumes( + det_bounds, 2, num_samples, error_rate=1e-6, seed=47) + exact_volumes = two_by_two_volume(det_bounds) + for det, volume in zip(det_bounds, exact_volumes): + computed_low, computed_high = volume_bounds[det] + self.assertLess(computed_low, volume) + self.assertGreater(computed_high, volume) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py index 88ed0127841093cc1a1168d988f14e7bb0277b12..11ca90c4833d84b092f0b43a8f5404e3a11450cd 100644 --- a/tensorflow/contrib/distributions/python/ops/autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py @@ -144,7 +144,7 @@ class Autoregressive(distribution_lib.Distribution): `distribution_fn(sample0).event_shape.num_elements()` are both `None`. ValueError: if `num_steps < 1`. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name) as name: self._distribution_fn = distribution_fn self._sample0 = sample0 diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py index bf5590cd552a915a3ecfc1912ee530baf79665a6..4714caad69ee4341d259f6677decdd5842931834 100644 --- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py +++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py @@ -41,9 +41,6 @@ class BatchReshape(distribution_lib.Distribution): This "meta-distribution" reshapes the batch dimensions of another distribution. - Note: Unlike `tf.reshape`, the `BatchReshape` distribution does not support - `-1` for flattening. 
- #### Examples ```python @@ -51,7 +48,7 @@ class BatchReshape(distribution_lib.Distribution): dtype = np.float32 dims = 2 - new_batch_shape = [1, 2, 3] + new_batch_shape = [1, 2, -1] old_batch_shape = [6] scale = np.ones(old_batch_shape + [dims], dtype) @@ -85,8 +82,9 @@ class BatchReshape(distribution_lib.Distribution): Args: distribution: The base distribution instance to reshape. Typically an instance of `Distribution`. - batch_shape: Positive `int`-like vector-shaped `Tensor` representing the - new shape of the batch dimensions. + batch_shape: Positive `int`-like vector-shaped `Tensor` representing + the new shape of the batch dimensions. Up to one dimension may contain + `-1`, meaning the remainder of the batch size. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -104,31 +102,28 @@ class BatchReshape(distribution_lib.Distribution): ValueError: if `batch_shape` size is not the same as a `distribution.batch_shape` size. """ - parameters = locals() + parameters = dict(locals()) name = name or "BatchReshape" + distribution.name - self._distribution = distribution with ops.name_scope(name, values=[batch_shape]) as name: - self._batch_shape_ = ops.convert_to_tensor( - batch_shape, - dtype=dtypes.int32, - name="batch_shape") - self._batch_shape_static = tensor_util.constant_value(self._batch_shape_) - if self._batch_shape_static is not None: - self._batch_shape_static = np.int32(self._batch_shape_static) - self._runtime_assertions = validate_init_args( - self._distribution, - self._batch_shape_, - validate_args, - self._batch_shape_static) + # The unexpanded batch shape may contain up to one dimension of -1. 
+ self._batch_shape_unexpanded = ops.convert_to_tensor( + batch_shape, dtype=dtypes.int32, name="batch_shape") + validate_init_args_statically(distribution, self._batch_shape_unexpanded) + batch_shape, batch_shape_static, runtime_assertions = calculate_reshape( + distribution.batch_shape_tensor(), self._batch_shape_unexpanded, + validate_args) + self._distribution = distribution + self._batch_shape_ = batch_shape + self._batch_shape_static = batch_shape_static + self._runtime_assertions = runtime_assertions super(BatchReshape, self).__init__( - dtype=self._distribution.dtype, - reparameterization_type=self._distribution.reparameterization_type, + dtype=distribution.dtype, + reparameterization_type=distribution.reparameterization_type, validate_args=validate_args, allow_nan_stats=allow_nan_stats, parameters=parameters, graph_parents=( - [self._batch_shape_] + - self._distribution._graph_parents), # pylint: disable=protected-access + [self._batch_shape_unexpanded] + distribution._graph_parents), # pylint: disable=protected-access name=name) @property @@ -140,7 +135,7 @@ class BatchReshape(distribution_lib.Distribution): return array_ops.identity(self._batch_shape_) def _batch_shape(self): - return tensor_shape.TensorShape(self._batch_shape_static) + return self._batch_shape_static def _event_shape_tensor(self): with ops.control_dependencies(self._runtime_assertions): @@ -152,11 +147,13 @@ class BatchReshape(distribution_lib.Distribution): def _sample_n(self, n, seed=None): with ops.control_dependencies(self._runtime_assertions): x = self.distribution.sample(sample_shape=n, seed=seed) - new_shape = array_ops.concat([ - [n], - self.batch_shape_tensor(), - self.event_shape_tensor(), - ], axis=0) + new_shape = array_ops.concat( + [ + [n], + self._batch_shape_unexpanded, + self.event_shape_tensor(), + ], + axis=0) return array_ops.reshape(x, new_shape) def _log_prob(self, x): @@ -213,9 +210,9 @@ class BatchReshape(distribution_lib.Distribution): event_ndims = 
(array_ops.size(self.event_shape_tensor()) if self.event_shape.ndims is None else self.event_shape.ndims) - batch_ndims = (array_ops.size(self.batch_shape_tensor()) - if self.batch_shape.ndims is None - else self.batch_shape.ndims) + batch_ndims = ( + array_ops.size(self._batch_shape_unexpanded) + if self.batch_shape.ndims is None else self.batch_shape.ndims) sample_ndims = x_ndims - batch_ndims - event_ndims if isinstance(sample_ndims, int): static_sample_shape = x.shape[:sample_ndims] @@ -238,10 +235,11 @@ class BatchReshape(distribution_lib.Distribution): self.event_shape_tensor(), ], axis=0) result = fn(array_ops.reshape(x, old_shape)) - new_shape = array_ops.concat([ - sample_shape, - self.batch_shape_tensor(), - ], axis=0) + new_shape = array_ops.concat( + [ + sample_shape, + self._batch_shape_unexpanded, + ], axis=0) result = array_ops.reshape(result, new_shape) if (static_sample_shape.ndims is not None and self.batch_shape.ndims is not None): @@ -261,8 +259,7 @@ class BatchReshape(distribution_lib.Distribution): if static_event_shape_list is None: static_event_shape_list = [self.event_shape] new_shape = array_ops.concat( - [self.batch_shape_tensor()] + event_shape_list, - axis=0) + [self._batch_shape_unexpanded] + event_shape_list, axis=0) result = array_ops.reshape(fn(), new_shape) if (self.batch_shape.ndims is not None and self.event_shape.ndims is not None): @@ -281,9 +278,9 @@ class BatchReshape(distribution_lib.Distribution): event_ndims = (array_ops.size(self.event_shape_tensor()) if self.event_shape.ndims is None else self.event_shape.ndims) - batch_ndims = (array_ops.size(self.batch_shape_tensor()) - if self.batch_shape.ndims is None - else self.batch_shape.ndims) + batch_ndims = ( + array_ops.size(self._batch_shape_unexpanded) + if self.batch_shape.ndims is None else self.batch_shape.ndims) expected_batch_event_ndims = batch_ndims + event_ndims if (isinstance(x_ndims, int) and @@ -355,62 +352,56 @@ class BatchReshape(distribution_lib.Distribution): 
return runtime_assertions -def validate_init_args( - distribution, - batch_shape, - validate_args, - batch_shape_static): +def calculate_reshape(original_shape, new_shape, validate=False, name=None): + """Calculates the reshaped dimensions (replacing up to one -1 in reshape).""" + batch_shape_static = tensor_util.constant_value_as_shape(new_shape) + if batch_shape_static.is_fully_defined(): + return np.int32(batch_shape_static.as_list()), batch_shape_static, [] + with ops.name_scope(name, "calculate_reshape", [original_shape, new_shape]): + original_size = math_ops.reduce_prod(original_shape) + implicit_dim = math_ops.equal(new_shape, -1) + size_implicit_dim = ( + original_size // math_ops.maximum(1, -math_ops.reduce_prod(new_shape))) + new_ndims = array_ops.shape(new_shape) + expanded_new_shape = array_ops.where( # Assumes exactly one `-1`. + implicit_dim, array_ops.fill(new_ndims, size_implicit_dim), new_shape) + validations = [] if not validate else [ + check_ops.assert_rank( + original_shape, 1, message="Original shape must be a vector."), + check_ops.assert_rank( + new_shape, 1, message="New shape must be a vector."), + check_ops.assert_less_equal( + math_ops.count_nonzero(implicit_dim, dtype=dtypes.int32), + 1, + message="At most one dimension can be unknown."), + check_ops.assert_positive( + expanded_new_shape, message="Shape elements must be >=-1."), + check_ops.assert_equal( + math_ops.reduce_prod(expanded_new_shape), + original_size, + message="Shape sizes do not match."), + ] + return expanded_new_shape, batch_shape_static, validations + + +def validate_init_args_statically(distribution, batch_shape): """Helper to __init__ which makes or raises assertions.""" - with ops.name_scope(name="validate_init_args", - values=[batch_shape] + distribution._graph_parents): # pylint: disable=protected-access - runtime_assertions = [] - - if batch_shape.shape.ndims is not None: - if batch_shape.shape.ndims != 1: - raise ValueError("`batch_shape` must be a vector " - 
"(saw rank: {}).".format( - batch_shape.shape.ndims)) - elif validate_args: - runtime_assertions += [ - check_ops.assert_rank( - batch_shape, - 1, - message="`batch_shape` must be a vector.", - name="assert_batch_shape_is_vector"), - ] - - batch_size_static = np.prod(batch_shape_static) - dist_batch_size_static = ( - None if not distribution.batch_shape.is_fully_defined() - else np.prod(distribution.batch_shape).value) - - if batch_size_static is not None and dist_batch_size_static is not None: - if batch_size_static != dist_batch_size_static: - raise ValueError("`batch_shape` size ({}) must match " - "`distribution.batch_shape` size ({}).".format( - batch_size_static, - dist_batch_size_static)) - elif validate_args: - runtime_assertions += [ - check_ops.assert_equal( - math_ops.reduce_prod(batch_shape), - math_ops.reduce_prod(distribution.batch_shape_tensor()), - message=("`batch_shape` size must match " - "`distributions.batch_shape` size."), - name="assert_batch_size"), - ] - - if batch_shape_static is not None: - if np.any(batch_shape_static < 1): - raise ValueError("`batch_shape` elements must be positive " - "(i.e., larger than zero).") - elif validate_args: - runtime_assertions += [ - check_ops.assert_positive( - batch_shape, - message=("`batch_shape` elements must be positive " - "(i.e., larger than zero)."), - name="assert_batch_shape_positive") - ] - - return runtime_assertions + if batch_shape.shape.ndims is not None: + if batch_shape.shape.ndims != 1: + raise ValueError("`batch_shape` must be a vector " + "(saw rank: {}).".format(batch_shape.shape.ndims)) + + batch_shape_static = tensor_util.constant_value_as_shape(batch_shape) + batch_size_static = batch_shape_static.num_elements() + dist_batch_size_static = distribution.batch_shape.num_elements() + + if batch_size_static is not None and dist_batch_size_static is not None: + if batch_size_static != dist_batch_size_static: + raise ValueError("`batch_shape` size ({}) must match " + 
"`distribution.batch_shape` size ({}).".format( + batch_size_static, dist_batch_size_static)) + + if batch_shape_static.dims is not None: + if any( + dim.value is not None and dim.value < 1 for dim in batch_shape_static): + raise ValueError("`batch_shape` elements must be >=-1.") diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index 51478dbeffaabc58ce3662f25f06bc579e8a407e..e141f8b5c6423bd6cce4d09da6f49d55b3e25a24 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -24,23 +24,27 @@ @@CholeskyOuterProduct @@ConditionalBijector @@Exp +@@FillTriangular @@Gumbel @@Identity @@Inline @@Invert @@Kumaraswamy @@MaskedAutoregressiveFlow +@@MatrixInverseTriL @@Ordered @@Permute @@PowerTransform @@RealNVP @@Reshape +@@ScaleTriL @@Sigmoid @@SinhArcsinh @@SoftmaxCentered @@Softplus @@Softsign @@Square +@@TransformDiagonal @@Weibull @@masked_autoregressive_default_template @@ -63,22 +67,26 @@ from tensorflow.contrib.distributions.python.ops.bijectors.chain import * from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import * from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import * from tensorflow.contrib.distributions.python.ops.bijectors.exp import * +from tensorflow.contrib.distributions.python.ops.bijectors.fill_triangular import * from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import * from tensorflow.contrib.distributions.python.ops.bijectors.inline import * from tensorflow.contrib.distributions.python.ops.bijectors.invert import * from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import * from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import * +from tensorflow.contrib.distributions.python.ops.bijectors.matrix_inverse_tril import * from 
tensorflow.contrib.distributions.python.ops.bijectors.ordered import * from tensorflow.contrib.distributions.python.ops.bijectors.permute import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import * from tensorflow.contrib.distributions.python.ops.bijectors.reshape import * +from tensorflow.contrib.distributions.python.ops.bijectors.scale_tril import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import * from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh import * from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import * from tensorflow.contrib.distributions.python.ops.bijectors.softplus import * from tensorflow.contrib.distributions.python.ops.bijectors.softsign import * from tensorflow.contrib.distributions.python.ops.bijectors.square import * +from tensorflow.contrib.distributions.python.ops.bijectors.transform_diagonal import * from tensorflow.python.ops.distributions.bijector import * from tensorflow.python.ops.distributions.identity_bijector import Identity diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py index b158a51bb022b5e2ea3afda74e97b9dc131665a6..16f959560ce0f171035b3ef0bd80b16dae1cc654 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py @@ -234,7 +234,7 @@ class Chain(bijector.Bijector): if not self.bijectors: return ildj - event_ndims = self._maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_static_event_ndims( self.inverse_min_event_ndims) if _use_static_shape(y, event_ndims): @@ -248,12 +248,15 @@ class Chain(bijector.Bijector): if _use_static_shape(y, event_ndims): event_shape = b.inverse_event_shape(event_shape) - event_ndims = self._maybe_get_event_ndims_statically( + 
event_ndims = self._maybe_get_static_event_ndims( event_shape.ndims) else: event_shape = b.inverse_event_shape_tensor(event_shape) - event_ndims = self._maybe_get_event_ndims_statically( - array_ops.size(event_shape)) + event_ndims = array_ops.size(event_shape) + event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) + if event_ndims_ is not None: + event_ndims = event_ndims_ + y = b.inverse(y, **kwargs.get(b.name, {})) return ildj @@ -270,7 +273,7 @@ class Chain(bijector.Bijector): if not self.bijectors: return fldj - event_ndims = self._maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_static_event_ndims( self.forward_min_event_ndims) if _use_static_shape(x, event_ndims): @@ -283,21 +286,14 @@ class Chain(bijector.Bijector): x, event_ndims=event_ndims, **kwargs.get(b.name, {})) if _use_static_shape(x, event_ndims): event_shape = b.forward_event_shape(event_shape) - event_ndims = self._maybe_get_event_ndims_statically(event_shape.ndims) + event_ndims = self._maybe_get_static_event_ndims(event_shape.ndims) else: event_shape = b.forward_event_shape_tensor(event_shape) - event_ndims = self._maybe_get_event_ndims_statically( - array_ops.size(event_shape)) + event_ndims = array_ops.size(event_shape) + event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) + if event_ndims_ is not None: + event_ndims = event_ndims_ x = b.forward(x, **kwargs.get(b.name, {})) return fldj - - def _maybe_get_event_ndims_statically(self, event_ndims): - event_ndims_ = super(Chain, self)._maybe_get_event_ndims_statically( - event_ndims) - if event_ndims_ is None: - return event_ndims - return event_ndims_ - - diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py new file mode 100644 index 0000000000000000000000000000000000000000..7b06325ead3c5c0923d5a3b915e9524e327f0d42 --- /dev/null +++ 
b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py @@ -0,0 +1,148 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""FillTriangular bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as dist_util + + +__all__ = [ + "FillTriangular", +] + + +class FillTriangular(bijector.Bijector): + """Transforms vectors to triangular. + + Triangular matrix elements are filled in a clockwise spiral. + + Given input with shape `batch_shape + [d]`, produces output with + shape `batch_shape + [n, n]`, where + `n = (-1 + sqrt(1 + 8 * d))/2`. + This follows by solving the quadratic equation + `d = 1 + 2 + ... + n = n * (n + 1)/2`. 
+ + #### Example + + ```python + b = tfb.FillTriangular(upper=False) + b.forward([1, 2, 3, 4, 5, 6]) + # ==> [[4, 0, 0], + # [6, 5, 0], + # [3, 2, 1]] + + b = tfb.FillTriangular(upper=True) + b.forward([1, 2, 3, 4, 5, 6]) + # ==> [[1, 2, 3], + # [0, 5, 6], + # [0, 0, 4]] + + ``` + """ + + def __init__(self, + upper=False, + validate_args=False, + name="fill_triangular"): + """Instantiates the `FillTriangular` bijector. + + Args: + upper: Python `bool` representing whether output matrix should be upper + triangular (`True`) or lower triangular (`False`, default). + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._upper = upper + super(FillTriangular, self).__init__( + forward_min_event_ndims=1, + inverse_min_event_ndims=2, + validate_args=validate_args, + name=name) + + def _forward(self, x): + return dist_util.fill_triangular(x, upper=self._upper) + + def _inverse(self, y): + return dist_util.fill_triangular_inverse(y, upper=self._upper) + + def _forward_log_det_jacobian(self, x): + return array_ops.zeros_like(x[..., 0]) + + def _inverse_log_det_jacobian(self, y): + return array_ops.zeros_like(y[..., 0, 0]) + + def _forward_event_shape(self, input_shape): + batch_shape, d = input_shape[:-1], input_shape[-1].value + if d is None: + n = None + else: + n = vector_size_to_square_matrix_size(d, self.validate_args) + return batch_shape.concatenate([n, n]) + + def _inverse_event_shape(self, output_shape): + batch_shape, n1, n2 = (output_shape[:-2], + output_shape[-2].value, + output_shape[-1].value) + if n1 is None or n2 is None: + m = None + elif n1 != n2: + raise ValueError("Matrix must be square. 
(saw [{}, {}])".format(n1, n2)) + else: + m = n1 * (n1 + 1) / 2 + return batch_shape.concatenate([m]) + + def _forward_event_shape_tensor(self, input_shape_tensor): + batch_shape, d = input_shape_tensor[:-1], input_shape_tensor[-1] + n = vector_size_to_square_matrix_size(d, self.validate_args) + return array_ops.concat([batch_shape, [n, n]], axis=0) + + def _inverse_event_shape_tensor(self, output_shape_tensor): + batch_shape, n = output_shape_tensor[:-2], output_shape_tensor[-1] + if self.validate_args: + is_square_matrix = check_ops.assert_equal( + n, output_shape_tensor[-2], message="Matrix must be square.") + with ops.control_dependencies([is_square_matrix]): + n = array_ops.identity(n) + d = math_ops.cast(n * (n + 1) / 2, output_shape_tensor.dtype) + return array_ops.concat([batch_shape, [d]], axis=0) + + +def vector_size_to_square_matrix_size(d, validate_args, name=None): + """Convert a vector size to a matrix size.""" + if isinstance(d, (float, int, np.generic, np.ndarray)): + n = (-1 + np.sqrt(1 + 8 * d)) / 2. + if float(int(n)) != n: + raise ValueError("Vector length is not a triangular number.") + return int(n) + else: + with ops.name_scope(name, "vector_size_to_square_matrix_size", [d]) as name: + n = (-1. + math_ops.sqrt(1 + 8. * math_ops.to_float(d))) / 2. + if validate_args: + with ops.control_dependencies([check_ops.assert_equal( + math_ops.to_float(math_ops.to_int32(n)), n, + message="Vector length is not a triangular number")]): + n = array_ops.identity(n) + return math_ops.cast(n, d.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py new file mode 100644 index 0000000000000000000000000000000000000000..71903f705232f0c5e5e0b3271550b4ef938c4f9d --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py @@ -0,0 +1,145 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""MatrixInverseTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector + + +__all__ = [ + "MatrixInverseTriL", +] + + +class MatrixInverseTriL(bijector.Bijector): + """Computes `g(L) = inv(L)`, where `L` is a lower-triangular matrix. + + `L` must be nonsingular; equivalently, all diagonal entries of `L` must be + nonzero. + + The input must have `rank >= 2`. The input is treated as a batch of matrices + with batch shape `input.shape[:-2]`, where each matrix has dimensions + `input.shape[-2]` by `input.shape[-1]` (hence `input.shape[-2]` must equal + `input.shape[-1]`). + + #### Examples + + ```python + tfd.bijectors.MatrixInverseTriL().forward(x=[[1., 0], [2, 1]]) + # Result: [[1., 0], [-2, 1]], i.e., inv(x) + + tfd.bijectors.MatrixInverseTriL().inverse(y=[[1., 0], [-2, 1]]) + # Result: [[1., 0], [2, 1]], i.e., inv(y). 
+ ``` + + """ + + def __init__(self, validate_args=False, name="matrix_inverse_tril"): + """Instantiates the `MatrixInverseTriL` bijector. + + Args: + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + super(MatrixInverseTriL, self).__init__( + forward_min_event_ndims=2, + validate_args=validate_args, + name=name) + + def _forward(self, x): + with ops.control_dependencies(self._assertions(x)): + shape = array_ops.shape(x) + return linalg_ops.matrix_triangular_solve( + x, linalg_ops.eye(shape[-1], batch_shape=shape[:-2]), lower=True) + + def _inverse(self, y): + return self._forward(y) + + def _forward_log_det_jacobian(self, x): + # Calculation of the Jacobian: + # + # Let X = (x_{ij}), 0 <= i,j < n, be a matrix of indeterminates. Let Z = + # X^{-1} where Z = (z_{ij}). Then + # + # dZ/dx_{ij} = (d/dt | t=0) Y(t)^{-1}, + # + # where Y(t) = X + t*E_{ij} and E_{ij} is the matrix with a 1 in the (i,j) + # entry and zeros elsewhere. By the product rule, + # + # 0 = d/dt [Identity matrix] + # = d/dt [Y Y^{-1}] + # = Y d/dt[Y^{-1}] + dY/dt Y^{-1} + # + # so + # + # d/dt[Y^{-1}] = -Y^{-1} dY/dt Y^{-1} + # = -Y^{-1} E_{ij} Y^{-1}. + # + # Evaluating at t=0, + # + # dZ/dx_{ij} = -Z E_{ij} Z. + # + # Taking the (r,s) entry of each side, + # + # dz_{rs}/dx_{ij} = -z_{ri}z_{sj}. + # + # Now, let J be the Jacobian dZ/dX, arranged as the n^2-by-n^2 matrix whose + # (r*n + s, i*n + j) entry is dz_{rs}/dx_{ij}. Considering J as an n-by-n + # block matrix with n-by-n blocks, the above expression for dz_{rs}/dx_{ij} + # shows that the block at position (r,i) is -z_{ri}Z. Hence + # + # J = -KroneckerProduct(Z, Z), + # det(J) = (-1)^(n^2) (det Z)^(2n) + # = (-1)^n (det X)^(-2n). + with ops.control_dependencies(self._assertions(x)): + return (-2. 
* math_ops.cast(array_ops.shape(x)[-1], x.dtype.base_dtype) * + math_ops.reduce_sum( + math_ops.log(math_ops.abs(array_ops.matrix_diag_part(x))), + axis=-1)) + + def _assertions(self, x): + if not self.validate_args: + return [] + shape = array_ops.shape(x) + is_matrix = check_ops.assert_rank_at_least( + x, 2, message="Input must have rank at least 2.") + is_square = check_ops.assert_equal( + shape[-2], shape[-1], message="Input must be a square matrix.") + above_diagonal = array_ops.matrix_band_part( + array_ops.matrix_set_diag( + x, array_ops.zeros(shape[:-1], dtype=dtypes.float32)), + 0, -1) + is_lower_triangular = check_ops.assert_equal( + above_diagonal, array_ops.zeros_like(above_diagonal), + message="Input must be lower triangular.") + # A lower triangular matrix is nonsingular iff all its diagonal entries are + # nonzero. + diag_part = array_ops.matrix_diag_part(x) + is_nonsingular = check_ops.assert_none_equal( + diag_part, array_ops.zeros_like(diag_part), + message="Input must have all diagonal entries nonzero.") + return [is_matrix, is_square, is_lower_triangular, is_nonsingular] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py new file mode 100644 index 0000000000000000000000000000000000000000..96bd242c634565987678143da3b43db6a9ea4966 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py @@ -0,0 +1,114 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ScaleTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops.bijectors import affine_scalar +from tensorflow.contrib.distributions.python.ops.bijectors import chain +from tensorflow.contrib.distributions.python.ops.bijectors import fill_triangular +from tensorflow.contrib.distributions.python.ops.bijectors import softplus +from tensorflow.contrib.distributions.python.ops.bijectors import transform_diagonal + +__all__ = [ + "ScaleTriL", +] + + +class ScaleTriL(chain.Chain): + """Transforms unconstrained vectors to TriL matrices with positive diagonal. + + This is implemented as a simple `tfb.Chain` of `tfb.FillTriangular` + followed by `tfb.TransformDiagonal`, and is provided mostly as a + convenience. The default setup is somewhat opinionated, using a + Softplus transformation followed by a small shift (`1e-5`) which + attempts to avoid numerical issues from zeros on the diagonal. + + #### Examples + + ```python + tfd = tf.contrib.distributions + tfb = tf.contrib.distributions.bijectors + b = tfb.ScaleTriL( + diag_bijector=tfb.Exp(), + diag_shift=None) + b.forward(x=[0., 0., 0.]) + # Result: [[1., 0.], + # [0., 1.]] + b.inverse(y=[[1., 0], + [.5, 2]]) + # Result: [log(2), .5, log(1)] + + # Define a distribution over PSD matrices of shape `[3, 3]`, + # with `1 + 2 + 3 = 6` degrees of freedom. + dist = tfd.TransformedDistribution( + tfd.Normal(tf.zeros(6), tf.ones(6)), + tfb.Chain([tfb.CholeskyOuterProduct(), tfb.ScaleTriL()])) + + # Using an identity transformation, ScaleTriL is equivalent to + # tfb.FillTriangular. 
+ b = tfb.ScaleTriL( + diag_bijector=tfb.Identity(), + diag_shift=None) + + # For greater control over initialization, one can manually encode + # pre- and post- shifts inside of `diag_bijector`. + b = tfb.ScaleTriL( + diag_bijector=tfb.Chain([ + tfb.AffineScalar(shift=1e-3), + tfb.Softplus(), + tfb.AffineScalar(shift=0.5413)]), # softplus_inverse(1.) + # = log(expm1(1.)) = 0.5413 + diag_shift=None) + ``` + """ + + def __init__(self, + diag_bijector=None, + diag_shift=1e-5, + validate_args=False, + name="scale_tril"): + """Instantiates the `ScaleTriL` bijector. + + Args: + diag_bijector: `Bijector` instance, used to transform the output diagonal + to be positive. + Default value: `None` (i.e., `tfb.Softplus()`). + diag_shift: Float value broadcastable and added to all diagonal entries + after applying the `diag_bijector`. Setting a positive + value forces the output diagonal entries to be positive, but + prevents inverting the transformation for matrices with + diagonal entries less than this value. + Default value: `1e-5`. Pass `None` to apply no shift. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + Default value: `False` (i.e., arguments are not validated). + name: Python `str` name given to ops managed by this object. + Default value: `scale_tril`. 
+ """ + + if diag_bijector is None: + diag_bijector = softplus.Softplus(validate_args=validate_args) + + if diag_shift is not None: + diag_bijector = chain.Chain([affine_scalar.AffineScalar(shift=diag_shift), + diag_bijector]) + + super(ScaleTriL, self).__init__( + [transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector), + fill_triangular.FillTriangular()], + validate_args=validate_args, + name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py b/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py new file mode 100644 index 0000000000000000000000000000000000000000..65669fc2bf92ce42f8086f3ea966343d40dd7f97 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py @@ -0,0 +1,102 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TransformDiagonal bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions import bijector + +__all__ = [ + "TransformDiagonal", +] + + +class TransformDiagonal(bijector.Bijector): + """Applies a Bijector to the diagonal of a matrix. 
+ + #### Example + + ```python + tfb = tf.contrib.distributions.bijectors + b = tfb.TransformDiagonal(diag_bijector=tfb.Exp()) + + b.forward([[1., 0.], + [0., 1.]]) + # ==> [[2.718, 0.], + # [0., 2.718]] + ``` + + """ + + def __init__(self, + diag_bijector, + validate_args=False, + name="transform_diagonal"): + """Instantiates the `TransformDiagonal` bijector. + + Args: + diag_bijector: `Bijector` instance used to transform the diagonal. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._diag_bijector = diag_bijector + super(TransformDiagonal, self).__init__( + forward_min_event_ndims=2, + inverse_min_event_ndims=2, + validate_args=validate_args, + name=name) + + def _forward(self, x): + diag = self._diag_bijector.forward(array_ops.matrix_diag_part(x)) + return array_ops.matrix_set_diag(x, diag) + + def _inverse(self, y): + diag = self._diag_bijector.inverse(array_ops.matrix_diag_part(y)) + return array_ops.matrix_set_diag(y, diag) + + def _forward_log_det_jacobian(self, x): + # We formulate the Jacobian with respect to the flattened matrices + # `vec(x)` and `vec(y)`. Suppose for notational convenience that + # the first `n` entries of `vec(x)` are the diagonal of `x`, and + # the remaining `n**2-n` entries are the off-diagonals in + # arbitrary order. Then the Jacobian is a block-diagonal matrix, + # with the Jacobian of the diagonal bijector in the first block, + # and the identity Jacobian for the remaining entries (since this + # bijector acts as the identity on non-diagonal entries): + # + # J_vec(x) (vec(y)) = + # ------------------------------- + # | J_diag(x) (diag(y)) 0 | n entries + # | | + # | 0 I | n**2-n entries + # ------------------------------- + # n n**2-n + # + # Since the log-det of the second (identity) block is zero, the + # overall log-det-jacobian is just the log-det of the first block, + # from the diagonal bijector. 
+ # + # Note that for elementwise operations (exp, softplus, etc) the + # first block of the Jacobian will itself be a diagonal matrix, + # but our implementation does not require this to be true. + return self._diag_bijector.forward_log_det_jacobian( + array_ops.matrix_diag_part(x), event_ndims=1) + + def _inverse_log_det_jacobian(self, y): + return self._diag_bijector.inverse_log_det_jacobian( + array_ops.matrix_diag_part(y), event_ndims=1) diff --git a/tensorflow/contrib/distributions/python/ops/binomial.py b/tensorflow/contrib/distributions/python/ops/binomial.py index 12d16031783b78dc3ea6273af77c1eaeb77ca94e..e4944beedcbca09b5eabd4daf1445ce4503b1c80 100644 --- a/tensorflow/contrib/distributions/python/ops/binomial.py +++ b/tensorflow/contrib/distributions/python/ops/binomial.py @@ -163,7 +163,7 @@ class Binomial(distribution.Distribution): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[total_count, logits, probs]) as name: self._total_count = self._maybe_assert_valid_total_count( ops.convert_to_tensor(total_count, name="total_count"), diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py index daacfe657fe154dce8d0db98894fe8b73546c476..23b6a83c17d58652001543047febeebabba0c69f 100644 --- a/tensorflow/contrib/distributions/python/ops/cauchy.py +++ b/tensorflow/contrib/distributions/python/ops/cauchy.py @@ -120,7 +120,7 @@ class Cauchy(distribution.Distribution): Raises: TypeError: if `loc` and `scale` have different `dtype`. 
""" - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py index c77c5fd20895a6220604d76a95a152a22cd3d914..686ae1ba74641e2b7b76667e512fa6453477a8da 100644 --- a/tensorflow/contrib/distributions/python/ops/chi2.py +++ b/tensorflow/contrib/distributions/python/ops/chi2.py @@ -83,7 +83,7 @@ class Chi2(gamma.Gamma): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = dict(locals()) # Even though all stats of chi2 are defined for valid parameters, this is # not true in the parent class "gamma." therefore, passing # allow_nan_stats=True @@ -119,7 +119,7 @@ class Chi2WithAbsDf(Chi2): validate_args=False, allow_nan_stats=True, name="Chi2WithAbsDf"): - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[df]) as name: super(Chi2WithAbsDf, self).__init__( df=math_ops.floor( diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py index 10b45361358b40a3c8fd725f27ad84ef9b8a37f5..3598c8d23ea9007fb359ae4931738fb61ede4ccc 100644 --- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py @@ -20,7 +20,6 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import conditional_distribution from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import 
math_ops from tensorflow.python.ops.distributions import transformed_distribution @@ -106,7 +105,7 @@ class ConditionalTransformedDistribution( bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) - event_ndims = self._maybe_get_event_ndims_statically() + event_ndims = self._maybe_get_static_event_ndims() ildj = self.bijector.inverse_log_det_jacobian( y, event_ndims=event_ndims, **bijector_kwargs) if self.bijector._is_injective: # pylint: disable=protected-access @@ -131,7 +130,7 @@ class ConditionalTransformedDistribution( bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) - event_ndims = self._maybe_get_event_ndims_statically() + event_ndims = self._maybe_get_static_event_ndims() ildj = self.bijector.inverse_log_det_jacobian( y, event_ndims=event_ndims, **bijector_kwargs) if self.bijector._is_injective: # pylint: disable=protected-access @@ -220,14 +219,14 @@ class ConditionalTransformedDistribution( inv_cdf = self.distribution.quantile(value, **distribution_kwargs) return self.bijector.forward(inv_cdf, **bijector_kwargs) - def _maybe_get_event_ndims_statically(self): + def _maybe_get_static_event_ndims(self): if self.event_shape.ndims is not None: return self.event_shape.ndims event_ndims = array_ops.size(self.event_shape_tensor()) - static_event_ndims = tensor_util.constant_value(event_ndims) + event_ndims_ = distribution_util.maybe_get_static_value(event_ndims) - if static_event_ndims is not None: - return static_event_ndims + if event_ndims_ is not None: + return event_ndims_ return event_ndims diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py index a42350430e98515e521ce357bf5a87ff2daefedc..c44c76a133817640449ba126bb8ca25abadba5e6 100644 --- a/tensorflow/contrib/distributions/python/ops/deterministic.py +++ 
b/tensorflow/contrib/distributions/python/ops/deterministic.py @@ -86,7 +86,7 @@ class _BaseDeterministic(distribution.Distribution): Raises: ValueError: If `loc` is a scalar. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, atol, rtol]) as name: loc = ops.convert_to_tensor(loc, name="loc") if is_vector and validate_args: diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index 289e1d50e1146a641c0cc433ece3465aed73b1c2..6959b3e8775d2dd488b4ee3252d143ef376d58f9 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -21,12 +21,19 @@ from __future__ import print_function from tensorflow.contrib import linalg from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib + +# The following two lines are redundant, in a sense. The first enables +# good coding practice *within* this file (`util.prefer_static_value` +# rather than `prefer_static_value`). The second ensures that users +# also get the core utils when they import this file. 
+from tensorflow.python.ops.distributions import util from tensorflow.python.ops.distributions.util import * # pylint: disable=wildcard-import @@ -484,3 +491,75 @@ def pad_mixture_dimensions(x, mixture_distribution, categorical_distribution, def static_value(x): """Returns the static value of a `Tensor` or `None`.""" return tensor_util.constant_value(ops.convert_to_tensor(x)) + + +def move_dimension(x, source_idx, dest_idx): + """Move a single tensor dimension within its shape. + + This is a special case of `tf.transpose()`, which applies + arbitrary permutations to tensor dimensions. + + Args: + x: Tensor of rank `ndims`. + source_idx: Integer index into `x.shape` (negative indexing is + supported). + dest_idx: Integer index into `x.shape` (negative indexing is + supported). + + Returns: + x_perm: Tensor of rank `ndims`, in which the dimension at original + index `source_idx` has been moved to new index `dest_idx`, with + all other dimensions retained in their original order. + + Example: + + ```python + x = tf.placeholder(tf.float32, shape=[200, 30, 4, 1, 6]) + x_perm = move_dimension(x, 1, 1) # no-op + x_perm = move_dimension(x, 0, 3) # result shape [30, 4, 1, 200, 6] + x_perm = move_dimension(x, 0, -2) # equivalent to previous + x_perm = move_dimension(x, 4, 2) # result shape [200, 30, 6, 4, 1] + ``` + """ + ndims = util.prefer_static_rank(x) + if isinstance(source_idx, int): + dtype = dtypes.int32 + else: + dtype = dtypes.as_dtype(source_idx.dtype) + + # Handle negative indexing. Since ndims might be dynamic, this makes + # source_idx and dest_idx also possibly dynamic. + if source_idx < 0: + source_idx = ndims + source_idx + if dest_idx < 0: + dest_idx = ndims + dest_idx + + # Construct the appropriate permutation of dimensions, depending + # on whether the source is before or after the destination. 
+ def move_left_permutation(): + return util.prefer_static_value( + array_ops.concat([ + math_ops.range(0, dest_idx, dtype=dtype), + [source_idx], + math_ops.range(dest_idx, source_idx, dtype=dtype), + math_ops.range(source_idx+1, ndims, dtype=dtype)], axis=0)) + + def move_right_permutation(): + return util.prefer_static_value( + array_ops.concat([ + math_ops.range(0, source_idx, dtype=dtype), + math_ops.range(source_idx+1, dest_idx+1, dtype=dtype), + [source_idx], + math_ops.range(dest_idx+1, ndims, dtype=dtype)], axis=0)) + + def x_permuted(): + return array_ops.transpose( + x, perm=smart_cond.smart_cond(source_idx < dest_idx, + move_right_permutation, + move_left_permutation)) + + # One final conditional to handle the special case where source + # and destination indices are equal. + return smart_cond.smart_cond(math_ops.equal(source_idx, dest_idx), + lambda: x, + x_permuted) diff --git a/tensorflow/contrib/distributions/python/ops/geometric.py b/tensorflow/contrib/distributions/python/ops/geometric.py index 53dd42f4c83fcea0ec5b1374c8e3109ebe1dd127..e1e42ee95d200df30c2c8a53a89cb5b7e9c4d17c 100644 --- a/tensorflow/contrib/distributions/python/ops/geometric.py +++ b/tensorflow/contrib/distributions/python/ops/geometric.py @@ -85,7 +85,7 @@ class Geometric(distribution.Distribution): name: Python `str` name prefixed to Ops created by this class. 
""" - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[logits, probs]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( logits, probs, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py index 2c261073ee16462599740cb241108bfe08c773ec..9d94fd11c62ce6ecd3d7daee35447bece2b4b2fb 100644 --- a/tensorflow/contrib/distributions/python/ops/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/gumbel.py @@ -124,7 +124,7 @@ class _Gumbel(distribution.Distribution): Raises: TypeError: if loc and scale are different dtypes. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py index d0df2befd6e46ca93e5a0b5d1cb5407d6719c7f2..9c96254d1c0a593b955231132330931ff5f4ad07 100644 --- a/tensorflow/contrib/distributions/python/ops/half_normal.py +++ b/tensorflow/contrib/distributions/python/ops/half_normal.py @@ -105,7 +105,7 @@ class HalfNormal(distribution.Distribution): if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. 
""" - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index fbde55ef310de1d926b8ddd503499fbed4809373..cd6eaa8407477b4ed92f169bc0d2d80644d7c956 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -116,7 +116,7 @@ class Independent(distribution_lib.Distribution): ValueError: if `reinterpreted_batch_ndims` exceeds `distribution.batch_ndims` """ - parameters = locals() + parameters = dict(locals()) name = name or "Independent" + distribution.name self._distribution = distribution with ops.name_scope(name) as name: diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 502bd4f493337bab180129cd0ddfaf5a76a0ca4e..208057b34db2881b5c9c2adb102d02a87a333007 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -125,7 +125,7 @@ class InverseGamma(distribution.Distribution): Raises: TypeError: if `concentration` and `rate` are different dtypes. 
""" - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[concentration, rate]) as name: with ops.control_dependencies([ check_ops.assert_positive(concentration), @@ -280,7 +280,7 @@ class InverseGammaWithSoftplusConcentrationRate(InverseGamma): validate_args=False, allow_nan_stats=True, name="InverseGammaWithSoftplusConcentrationRate"): - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[concentration, rate]) as name: super(InverseGammaWithSoftplusConcentrationRate, self).__init__( concentration=nn.softplus(concentration, diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py index 66682b2ff5493f8565410138e770b45ffc6b5d77..0ff989fc952c6fb3f54dad9a943eb36a0494a3be 100644 --- a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py +++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py @@ -31,7 +31,6 @@ from tensorflow.python.ops import special_math_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import uniform -from tensorflow.python.util.tf_export import tf_export __all__ = [ "Kumaraswamy", @@ -59,7 +58,6 @@ def _harmonic_number(x): return math_ops.digamma(x + one) - math_ops.digamma(one) -@tf_export("distributions.Kumaraswamy") class Kumaraswamy(transformed_distribution.TransformedDistribution): """Kumaraswamy distribution. 
diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py index c83b5bc2e3a8c56f5c52d063a7d0d399be1c1870..27aa863440574eb0cdb5c7ae326e877d472999ad 100644 --- a/tensorflow/contrib/distributions/python/ops/logistic.py +++ b/tensorflow/contrib/distributions/python/ops/logistic.py @@ -119,7 +119,7 @@ class Logistic(distribution.Distribution): Raises: TypeError: if loc and scale are different dtypes. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py index 2ef294af2e8bc9beff735ec2e0fd6b619ce96176..bfb53a06c011cec60cf5b2132e4b1106128a1ece 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture.py +++ b/tensorflow/contrib/distributions/python/ops/mixture.py @@ -116,7 +116,7 @@ class Mixture(distribution.Distribution): matching static batch shapes, or all components do not have matching static event shapes. """ - parameters = locals() + parameters = dict(locals()) if not isinstance(cat, categorical.Categorical): raise TypeError("cat must be a Categorical distribution, but saw: %s" % cat) diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index 0b1301e551728f74bb0048d2dcf3c356ae110c75..112eefd3691815ead19d59bc3aef5909b27ed169 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -130,7 +130,7 @@ class MixtureSameFamily(distribution.Distribution): ValueError: if `mixture_distribution` categories does not equal `components_distribution` rightmost batch shape. 
""" - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name) as name: self._mixture_distribution = mixture_distribution self._components_distribution = components_distribution diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py index e3236c2db93695a5e007bba9a1414773f3935f2e..d2beb2aff0481eb4ec3a3abbf44fad5efff8eedd 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py @@ -193,7 +193,7 @@ class MultivariateNormalDiag( Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name) as name: with ops.name_scope("init", values=[ loc, scale_diag, scale_identity_multiplier]): @@ -224,7 +224,7 @@ class MultivariateNormalDiagWithSoftplusScale(MultivariateNormalDiag): validate_args=False, allow_nan_stats=True, name="MultivariateNormalDiagWithSoftplusScale"): - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[scale_diag]) as name: super(MultivariateNormalDiagWithSoftplusScale, self).__init__( loc=loc, diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py index 2f6a6f198cbcfbdcbd0993d3074ddde1c389585f..5117379b047f5e510a8a1a5490ddf76ee93d9d74 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -215,7 +215,7 @@ class MultivariateNormalDiagPlusLowRank( Raises: ValueError: if at most `scale_identity_multiplier` is specified. 
""" - parameters = locals() + parameters = dict(locals()) def _convert_to_tensor(x, name): return None if x is None else ops.convert_to_tensor(x, name=name) with ops.name_scope(name) as name: diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index 5d06a396fe7a3b87cabb9c3081da45246854089f..57f47db50c496f1e3e80d8177560b1bab594eb56 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -155,7 +155,7 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): Raises: ValueError: if neither `loc` nor `covariance_matrix` are specified. """ - parameters = locals() + parameters = dict(locals()) # Convert the covariance_matrix up to a scale_tril and call MVNTriL. with ops.name_scope(name) as name: diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 44c92312c7dc758500051f89923ec9fafe850c0e..6a0383db02555274239ee0b1845f24a705270d84 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -170,7 +170,7 @@ class MultivariateNormalLinearOperator( ValueError: if `scale` is unspecified. 
TypeError: if not `scale.dtype.is_floating` """ - parameters = locals() + parameters = dict(locals()) if scale is None: raise ValueError("Missing required `scale` parameter.") if not scale.dtype.is_floating: diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index d6f8b731cbeed5fed3b43365e7c668d0434a267e..c809ef3c1cb5b8b9cd892b98d81e57710807d0aa 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -179,7 +179,7 @@ class MultivariateNormalTriL( Raises: ValueError: if neither `loc` nor `scale_tril` are specified. """ - parameters = locals() + parameters = dict(locals()) def _convert_to_tensor(x, name): return None if x is None else ops.convert_to_tensor(x, name=name) if loc is None and scale_tril is None: diff --git a/tensorflow/contrib/distributions/python/ops/negative_binomial.py b/tensorflow/contrib/distributions/python/ops/negative_binomial.py index eeaf9c0a5ebc1323e137ff73f82588f6907031c7..2bd11e24b315e044624344580108a232d1b6da89 100644 --- a/tensorflow/contrib/distributions/python/ops/negative_binomial.py +++ b/tensorflow/contrib/distributions/python/ops/negative_binomial.py @@ -90,7 +90,7 @@ class NegativeBinomial(distribution.Distribution): name: Python `str` name prefixed to Ops created by this class. 
""" - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[total_count, logits, probs]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( logits, probs, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py index 305b138fdc2318523ee078195213caf865d96b4d..3e44c10fab726ad1299cc852a5e1391fecb8b390 100644 --- a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py @@ -115,7 +115,7 @@ class OneHotCategorical(distribution.Distribution): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[logits, probs]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( name=name, logits=logits, probs=probs, validate_args=validate_args, diff --git a/tensorflow/contrib/distributions/python/ops/poisson.py b/tensorflow/contrib/distributions/python/ops/poisson.py index a84aad6fc9372395ac021fa3aa006ddf9272e6a9..04de8106ee0c06f4bc888964e053eb3123f3dab3 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson.py +++ b/tensorflow/contrib/distributions/python/ops/poisson.py @@ -93,7 +93,7 @@ class Poisson(distribution.Distribution): TypeError: if `rate` is not a float-type. TypeError: if `log_rate` is not a float-type. 
""" - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[rate]) as name: if (rate is None) == (log_rate is None): raise ValueError("Must specify exactly one of `rate` and `log_rate`.") diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index 19c99dcee92978e938a73af9be445cd098e5fe90..7b10ba998f0ceac37571524ce858bbd4c87455fe 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -255,7 +255,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): TypeError: if `quadrature_grid` and `quadrature_probs` have different base `dtype`. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale]) as name: if loc is not None: loc = ops.convert_to_tensor(loc, name="loc") diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py index eb94760ad71f5babaedaafd3f7990b40aaad85c2..5ac6c34b538016af376f53aa5a889e78c1f65f5f 100644 --- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py @@ -263,7 +263,7 @@ class QuantizedDistribution(distributions.Distribution): `Distribution` or continuous. NotImplementedError: If the base distribution does not implement `cdf`. 
""" - parameters = locals() + parameters = dict(locals()) values = ( list(distribution.parameters.values()) + [low, high]) diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py index 84c8d29072c2f1f3888329638c4695bccf70eab7..4182ca2b56ea80dba71787b006a1652e0f979694 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py @@ -165,7 +165,7 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution): Raises: ValueError: If both `probs` and `logits` are passed, or if neither. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[logits, probs, temperature]) as name: with ops.control_dependencies([check_ops.assert_positive(temperature)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index 325f41e37c928ba8e81e45e63a7f7f8126bc80f8..5414f347cd65e2d3327d1934cbc7a91e7f780fc5 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -162,7 +162,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. 
""" - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[logits, probs, temperature]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( diff --git a/tensorflow/contrib/distributions/python/ops/seed_stream.py b/tensorflow/contrib/distributions/python/ops/seed_stream.py index 056d349688511e19a4fa3d58a5b3c1c8355671a3..cf505ac627b62ae0a3d1ec1ce2a237c3c2ff1b74 100644 --- a/tensorflow/contrib/distributions/python/ops/seed_stream.py +++ b/tensorflow/contrib/distributions/python/ops/seed_stream.py @@ -169,7 +169,7 @@ class SeedStream(object): and TensorFlow Probability code base. See class docstring for rationale. """ - self._seed = seed + self._seed = seed.original_seed if isinstance(seed, SeedStream) else seed self._salt = salt self._counter = 0 diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py index 03828fa61277eeaf7ce90de8023b4ed91f6cc4dc..a764544932cea8a624820153e383595fec9d7fc6 100644 --- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -132,7 +132,7 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. 
""" - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale, skewness, tailweight]) as name: diff --git a/tensorflow/contrib/distributions/python/ops/statistical_testing.py b/tensorflow/contrib/distributions/python/ops/statistical_testing.py index 9c69435fac109914ff29b307dfad105f62849339..c25e8c51d7705b641699fb05623c7b0fb4950e1b 100644 --- a/tensorflow/contrib/distributions/python/ops/statistical_testing.py +++ b/tensorflow/contrib/distributions/python/ops/statistical_testing.py @@ -140,6 +140,7 @@ __all__ = [ "assert_true_mean_equal_by_dkwm", "min_discrepancy_of_true_means_detectable_by_dkwm", "min_num_samples_for_dkwm_mean_test", + "assert_true_mean_in_interval_by_dkwm", "assert_true_mean_equal_by_dkwm_two_sample", "min_discrepancy_of_true_means_detectable_by_dkwm_two_sample", "min_num_samples_for_dkwm_mean_two_sample_test", @@ -209,17 +210,17 @@ def _maximum_mean(samples, envelope, high, name=None): separately. Args: - samples: Floating-point tensor of samples from the distribution(s) + samples: Floating-point `Tensor` of samples from the distribution(s) of interest. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `envelope` and `high`. - envelope: Floating-point tensor of sizes of admissible CDF + envelope: Floating-point `Tensor` of sizes of admissible CDF envelopes (i.e., the `eps` above). - high: Floating-point tensor of upper bounds on the distributions' - supports. + high: Floating-point `Tensor` of upper bounds on the distributions' + supports. `samples <= high`. name: A name for this operation (optional). Returns: - bound: Floating-point tensor of upper bounds on the true means. + bound: Floating-point `Tensor` of upper bounds on the true means. Raises: InvalidArgumentError: If some `sample` is found to be larger than @@ -254,17 +255,17 @@ def _minimum_mean(samples, envelope, low, name=None): separately. 
Args: - samples: Floating-point tensor of samples from the distribution(s) + samples: Floating-point `Tensor` of samples from the distribution(s) of interest. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `envelope` and `low`. - envelope: Floating-point tensor of sizes of admissible CDF + envelope: Floating-point `Tensor` of sizes of admissible CDF envelopes (i.e., the `eps` above). - low: Floating-point tensor of lower bounds on the distributions' - supports. + low: Floating-point `Tensor` of lower bounds on the distributions' + supports. `samples >= low`. name: A name for this operation (optional). Returns: - bound: Floating-point tensor of lower bounds on the true means. + bound: Floating-point `Tensor` of lower bounds on the true means. Raises: InvalidArgumentError: If some `sample` is found to be smaller than @@ -300,12 +301,12 @@ def _dkwm_cdf_envelope(n, error_rate, name=None): probability above. Args: - n: Tensor of numbers of samples drawn. - error_rate: Floating-point tensor of admissible rates of mistakes. + n: `Tensor` of numbers of samples drawn. + error_rate: Floating-point `Tensor` of admissible rates of mistakes. name: A name for this operation (optional). Returns: - eps: Tensor of maximum distances the true CDF can be from the + eps: `Tensor` of maximum distances the true CDF can be from the empirical CDF. This scales as `O(sqrt(-log(error_rate)))` and as `O(1 / sqrt(n))`. The shape is the broadcast of `n` and `error_rate`. @@ -324,8 +325,8 @@ def _check_shape_dominates(samples, parameters): sample counts end up inflated. Args: - samples: A Tensor whose shape is to be protected against broadcasting. - parameters: A list of Tensors who are parameters for the statistical test. + samples: A `Tensor` whose shape is to be protected against broadcasting. + parameters: A list of `Tensor`s who are parameters for the statistical test. 
Returns: samples: Return original `samples` with control dependencies attached @@ -369,19 +370,23 @@ def true_mean_confidence_interval_by_dkwm( members. Args: - samples: Floating-point tensor of samples from the distribution(s) + samples: Floating-point `Tensor` of samples from the distribution(s) of interest. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `low` and `high`. - low: Floating-point tensor of lower bounds on the distributions' + The support is bounded: `low <= samples <= high`. + low: Floating-point `Tensor` of lower bounds on the distributions' supports. - high: Floating-point tensor of upper bounds on the distributions' + high: Floating-point `Tensor` of upper bounds on the distributions' supports. - error_rate: *Scalar* admissible total rate of mistakes. + error_rate: *Scalar* floating-point `Tensor` admissible total rate + of mistakes. name: A name for this operation (optional). Returns: - low: A floating-point tensor of stochastic lower bounds on the true means. - high: A floating-point tensor of stochastic upper bounds on the true means. + low: A floating-point `Tensor` of stochastic lower bounds on the + true means. + high: A floating-point `Tensor` of stochastic upper bounds on the + true means. """ with ops.name_scope( name, "true_mean_confidence_interval_by_dkwm", @@ -436,15 +441,17 @@ def assert_true_mean_equal_by_dkwm( the assertion will insist on stronger evidence to fail any one member. Args: - samples: Floating-point tensor of samples from the distribution(s) + samples: Floating-point `Tensor` of samples from the distribution(s) of interest. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `low` and `high`. - low: Floating-point tensor of lower bounds on the distributions' + The support is bounded: `low <= samples <= high`. + low: Floating-point `Tensor` of lower bounds on the distributions' supports. 
- high: Floating-point tensor of upper bounds on the distributions' + high: Floating-point `Tensor` of upper bounds on the distributions' supports. - expected: Floating-point tensor of expected true means. - false_fail_rate: *Scalar* admissible total rate of mistakes. + expected: Floating-point `Tensor` of expected true means. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of mistakes. name: A name for this operation (optional). Returns: @@ -454,20 +461,8 @@ def assert_true_mean_equal_by_dkwm( with ops.name_scope( name, "assert_true_mean_equal_by_dkwm", [samples, low, high, expected, false_fail_rate]): - samples = ops.convert_to_tensor(samples, name="samples") - low = ops.convert_to_tensor(low, name="low") - high = ops.convert_to_tensor(high, name="high") - expected = ops.convert_to_tensor(expected, name="expected") - false_fail_rate = ops.convert_to_tensor( - false_fail_rate, name="false_fail_rate") - samples = _check_shape_dominates(samples, [low, high, expected]) - min_mean, max_mean = true_mean_confidence_interval_by_dkwm( - samples, low, high, error_rate=false_fail_rate) - less_op = check_ops.assert_less( - min_mean, expected, message="Mean confidence interval too high") - with ops.control_dependencies([less_op]): - return check_ops.assert_greater( - max_mean, expected, message="Mean confidence interval too low") + return assert_true_mean_in_interval_by_dkwm( + samples, low, high, expected, expected, false_fail_rate) def min_discrepancy_of_true_means_detectable_by_dkwm( @@ -487,30 +482,35 @@ def min_discrepancy_of_true_means_detectable_by_dkwm( with the same `false_pass_rate`. Args: - n: Tensor of numbers of samples to be drawn from the distributions + n: `Tensor` of numbers of samples to be drawn from the distributions of interest. - low: Floating-point tensor of lower bounds on the distributions' + low: Floating-point `Tensor` of lower bounds on the distributions' supports. 
- high: Floating-point tensor of upper bounds on the distributions' + high: Floating-point `Tensor` of upper bounds on the distributions' supports. - false_fail_rate: *Scalar* admissible total rate of false failures. - false_pass_rate: *Scalar* admissible rate of false passes. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of false failures. + false_pass_rate: *Scalar* floating-point `Tensor` admissible rate + of false passes. name: A name for this operation (optional). Returns: - discr: Tensor of lower bounds on the distances between true + discr: `Tensor` of lower bounds on the distances between true means detectable by a DKWM-based test. For each batch member `i`, of `K` total, drawing `n[i]` samples from some scalar distribution supported on `[low[i], high[i]]` is enough to detect a difference in means of size `discr[i]` or more. Specifically, we guarantee that (a) if the true mean is the expected - mean, `assert_true_mean_equal_by_dkwm` will fail with probability at - most `false_fail_rate / K` (which amounts to `false_fail_rate` if - applied to the whole batch at once), and (b) if the true mean - differs from the expected mean by at least `discr[i]`, - `assert_true_mean_equal_by_dkwm` will pass with probability at most - `false_pass_rate`. + mean (resp. in the expected interval), then `assert_true_mean_equal_by_dkwm` + (resp. `assert_true_mean_in_interval_by_dkwm`) will fail with + probability at most `false_fail_rate / K` (which amounts to + `false_fail_rate` if applied to the whole batch at once), and (b) if + the true mean differs from the expected mean (resp. falls outside + the expected interval) by at least `discr[i]`, + `assert_true_mean_equal_by_dkwm` + (resp. `assert_true_mean_in_interval_by_dkwm`) will pass with + probability at most `false_pass_rate`. The detectable discrepancy scales as @@ -558,17 +558,19 @@ def min_num_samples_for_dkwm_mean_test( on a scalar distribution supported on `[low, high]`. 
Args: - discrepancy: Floating-point tensor of desired upper limits on mean + discrepancy: Floating-point `Tensor` of desired upper limits on mean differences that may go undetected with probability higher than `1 - false_pass_rate`. - low: Tensor of lower bounds on the distributions' support. - high: Tensor of upper bounds on the distributions' support. - false_fail_rate: *Scalar* admissible total rate of false failures. - false_pass_rate: *Scalar* admissible rate of false passes. + low: `Tensor` of lower bounds on the distributions' support. + high: `Tensor` of upper bounds on the distributions' support. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of false failures. + false_pass_rate: *Scalar* floating-point `Tensor` admissible rate + of false passes. name: A name for this operation (optional). Returns: - n: Tensor of numbers of samples to be drawn from the distributions + n: `Tensor` of numbers of samples to be drawn from the distributions of interest. The `discrepancy`, `low`, and `high` tensors must have @@ -578,12 +580,15 @@ def min_num_samples_for_dkwm_mean_test( some scalar distribution supported on `[low[i], high[i]]` is enough to detect a difference in means of size `discrepancy[i]` or more. Specifically, we guarantee that (a) if the true mean is the expected - mean, `assert_true_mean_equal_by_dkwm` will fail with probability at - most `false_fail_rate / K` (which amounts to `false_fail_rate` if - applied to the whole batch at once), and (b) if the true mean - differs from the expected mean by at least `discrepancy[i]`, - `assert_true_mean_equal_by_dkwm` will pass with probability at most - `false_pass_rate`. + mean (resp. in the expected interval), then `assert_true_mean_equal_by_dkwm` + (resp. 
`assert_true_mean_in_interval_by_dkwm`) will fail with + probability at most `false_fail_rate / K` (which amounts to + `false_fail_rate` if applied to the whole batch at once), and (b) if + the true mean differs from the expected mean (resp. falls outside + the expected interval) by at least `discrepancy[i]`, + `assert_true_mean_equal_by_dkwm` + (resp. `assert_true_mean_in_interval_by_dkwm`) will pass with + probability at most `false_pass_rate`. The required number of samples scales as `O((high[i] - low[i])**2)`, `O(-log(false_fail_rate/K))`, @@ -610,6 +615,76 @@ def min_num_samples_for_dkwm_mean_test( return math_ops.maximum(n1, n2) +def assert_true_mean_in_interval_by_dkwm( + samples, low, high, expected_low, expected_high, + false_fail_rate=1e-6, name=None): + """Asserts the mean of the given distribution is in the given interval. + + More precisely, fails if there is enough evidence (using the + [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval)) + that the mean of the distribution from which the given samples are + drawn is _outside_ the given interval with statistical significance + `false_fail_rate` or stronger, otherwise passes. If you also want + to check that you are gathering enough evidence that a pass is not + spurious, see `min_num_samples_for_dkwm_mean_test` and + `min_discrepancy_of_true_means_detectable_by_dkwm`. + + Note that `false_fail_rate` is a total false failure rate for all + the assertions in the batch. As such, if the batch is nontrivial, + the assertion will insist on stronger evidence to fail any one member. + + Args: + samples: Floating-point `Tensor` of samples from the distribution(s) + of interest. Entries are assumed IID across the 0th dimension. + The other dimensions must broadcast with `low` and `high`. + The support is bounded: `low <= samples <= high`. + low: Floating-point `Tensor` of lower bounds on the distributions' + supports. 
+ high: Floating-point `Tensor` of upper bounds on the distributions' + supports. + expected_low: Floating-point `Tensor` of lower bounds on the + expected true means. + expected_high: Floating-point `Tensor` of upper bounds on the + expected true means. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of mistakes. + name: A name for this operation (optional). + + Returns: + check: Op that raises `InvalidArgumentError` if any expected mean + interval does not overlap with the corresponding confidence + interval. + """ + with ops.name_scope( + name, "assert_true_mean_in_interval_by_dkwm", + [samples, low, high, expected_low, expected_high, false_fail_rate]): + samples = ops.convert_to_tensor(samples, name="samples") + low = ops.convert_to_tensor(low, name="low") + high = ops.convert_to_tensor(high, name="high") + expected_low = ops.convert_to_tensor(expected_low, name="expected_low") + expected_high = ops.convert_to_tensor(expected_high, name="expected_high") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + samples = _check_shape_dominates( + samples, [low, high, expected_low, expected_high]) + min_mean, max_mean = true_mean_confidence_interval_by_dkwm( + samples, low, high, false_fail_rate) + # Assert that the interval [min_mean, max_mean] intersects the + # interval [expected_low, expected_high]. This is true if + # max_mean >= expected_low and min_mean <= expected_high. + # By DeMorgan's law, that's also equivalent to + # not (max_mean < expected_low or min_mean > expected_high), + # which is a way of saying the two intervals are not disjoint. 
+ check_confidence_interval_can_intersect = check_ops.assert_greater_equal( + max_mean, expected_low, message="Confidence interval does not " + "intersect: true mean smaller than expected") + with ops.control_dependencies([check_confidence_interval_can_intersect]): + return check_ops.assert_less_equal( + min_mean, expected_high, message="Confidence interval does not " + "intersect: true mean greater than expected") + + def assert_true_mean_equal_by_dkwm_two_sample( samples1, low1, high1, samples2, low2, high2, false_fail_rate=1e-6, name=None): @@ -630,23 +705,26 @@ def assert_true_mean_equal_by_dkwm_two_sample( the assertion will insist on stronger evidence to fail any one member. Args: - samples1: Floating-point tensor of samples from the + samples1: Floating-point `Tensor` of samples from the distribution(s) A. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `low1`, `high1`, `low2`, and `high2`. - low1: Floating-point tensor of lower bounds on the supports of the + The support is bounded: `low1 <= samples1 <= high1`. + low1: Floating-point `Tensor` of lower bounds on the supports of the distributions A. - high1: Floating-point tensor of upper bounds on the supports of + high1: Floating-point `Tensor` of upper bounds on the supports of the distributions A. - samples2: Floating-point tensor of samples from the + samples2: Floating-point `Tensor` of samples from the distribution(s) B. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `low1`, `high1`, `low2`, and `high2`. - low2: Floating-point tensor of lower bounds on the supports of the + The support is bounded: `low2 <= samples2 <= high2`. + low2: Floating-point `Tensor` of lower bounds on the supports of the distributions B. - high2: Floating-point tensor of upper bounds on the supports of + high2: Floating-point `Tensor` of upper bounds on the supports of the distributions B. 
- false_fail_rate: *Scalar* admissible total rate of mistakes. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of mistakes. name: A name for this operation (optional). Returns: @@ -676,20 +754,10 @@ def assert_true_mean_equal_by_dkwm_two_sample( # and sample counts should be valid; however, because the intervals # scale as O(-log(false_fail_rate)), there doesn't seem to be much # room to win. - min_mean_1, max_mean_1 = true_mean_confidence_interval_by_dkwm( - samples1, low1, high1, false_fail_rate / 2.) min_mean_2, max_mean_2 = true_mean_confidence_interval_by_dkwm( samples2, low2, high2, false_fail_rate / 2.) - # I want to assert - # not (max_mean_1 < min_mean_2 or min_mean_1 > max_mean_2), - # but I think I only have and-combination of asserts, so use DeMorgan. - check_confidence_intervals_can_intersect = check_ops.assert_greater_equal( - max_mean_1, min_mean_2, message="Confidence intervals do not " - "intersect: samples1 has a smaller mean than samples2") - with ops.control_dependencies([check_confidence_intervals_can_intersect]): - return check_ops.assert_less_equal( - min_mean_1, max_mean_2, message="Confidence intervals do not " - "intersect: samples2 has a smaller mean than samples1") + return assert_true_mean_in_interval_by_dkwm( + samples1, low1, high1, min_mean_2, max_mean_2, false_fail_rate / 2.) def min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( @@ -710,22 +778,24 @@ def min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( with the same `false_pass_rate`. Args: - n1: Tensor of numbers of samples to be drawn from the distributions A. - low1: Floating-point tensor of lower bounds on the supports of the + n1: `Tensor` of numbers of samples to be drawn from the distributions A. + low1: Floating-point `Tensor` of lower bounds on the supports of the distributions A. 
- high1: Floating-point tensor of upper bounds on the supports of + high1: Floating-point `Tensor` of upper bounds on the supports of the distributions A. - n2: Tensor of numbers of samples to be drawn from the distributions B. - low2: Floating-point tensor of lower bounds on the supports of the + n2: `Tensor` of numbers of samples to be drawn from the distributions B. + low2: Floating-point `Tensor` of lower bounds on the supports of the distributions B. - high2: Floating-point tensor of upper bounds on the supports of + high2: Floating-point `Tensor` of upper bounds on the supports of the distributions B. - false_fail_rate: *Scalar* admissible total rate of false failures. - false_pass_rate: *Scalar* admissible rate of false passes. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of false failures. + false_pass_rate: *Scalar* floating-point `Tensor` admissible rate + of false passes. name: A name for this operation (optional). Returns: - discr: Tensor of lower bounds on the distances between true means + discr: `Tensor` of lower bounds on the distances between true means detectable by a two-sample DKWM-based test. For each batch member `i`, of `K` total, drawing `n1[i]` samples @@ -776,24 +846,26 @@ def min_num_samples_for_dkwm_mean_two_sample_test( (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval). Args: - discrepancy: Floating-point tensor of desired upper limits on mean + discrepancy: Floating-point `Tensor` of desired upper limits on mean differences that may go undetected with probability higher than `1 - false_pass_rate`. - low1: Floating-point tensor of lower bounds on the supports of the + low1: Floating-point `Tensor` of lower bounds on the supports of the distributions A. - high1: Floating-point tensor of upper bounds on the supports of + high1: Floating-point `Tensor` of upper bounds on the supports of the distributions A. 
- low2: Floating-point tensor of lower bounds on the supports of the + low2: Floating-point `Tensor` of lower bounds on the supports of the distributions B. - high2: Floating-point tensor of upper bounds on the supports of + high2: Floating-point `Tensor` of upper bounds on the supports of the distributions B. - false_fail_rate: *Scalar* admissible total rate of false failures. - false_pass_rate: *Scalar* admissible rate of false passes. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of false failures. + false_pass_rate: *Scalar* floating-point `Tensor` admissible rate + of false passes. name: A name for this operation (optional). Returns: - n1: Tensor of numbers of samples to be drawn from the distributions A. - n2: Tensor of numbers of samples to be drawn from the distributions B. + n1: `Tensor` of numbers of samples to be drawn from the distributions A. + n2: `Tensor` of numbers of samples to be drawn from the distributions B. For each batch member `i`, of `K` total, drawing `n1[i]` samples from scalar distribution A supported on `[low1[i], high1[i]]` and `n2[i]` diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index af6ff8162b173015dca2d568e13d63127af7853a..8d4914e16cd3748e81e3d9b3be8b35f64a1c6f0d 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -395,7 +395,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): ValueError: if `not distribution.is_scalar_batch`. ValueError: if `not distribution.is_scalar_event`. 
""" - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[mix_loc, temperature]) as name: if not scale or len(scale) < 2: raise ValueError("Must specify list (or list-like object) of scale " diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py index e265b5d0f7c10b2782a1a8924babdca9b986f622..a75b3f3df1f2867f214f47051fa358b79a52a35e 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py @@ -175,7 +175,7 @@ class VectorExponentialDiag( Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name) as name: with ops.name_scope("init", values=[ loc, scale_diag, scale_identity_multiplier]): diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py index 89136d6760bb663b5ff86a77c5945ce900f072b9..a7d4c55be93f6190ae4d6976030190f27dcfe48f 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py @@ -175,7 +175,7 @@ class VectorExponentialLinearOperator( ValueError: if `scale` is unspecified. 
TypeError: if not `scale.dtype.is_floating` """ - parameters = locals() + parameters = dict(locals()) if scale is None: raise ValueError("Missing required `scale` parameter.") if not scale.dtype.is_floating: diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py index 8dd983b750d9b39775e570800006011f4968f7f3..4a53e7a621f27382d2995798f724392d34459670 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py @@ -210,7 +210,7 @@ class VectorLaplaceDiag( Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name): with ops.name_scope("init", values=[ loc, scale_diag, scale_identity_multiplier]): diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py index ec485c95c15da2794b67d2699d2bdd9db97bb6c4..0566e04fece6f9ca0de6903ce5c424eccbc003cd 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py @@ -191,7 +191,7 @@ class VectorLaplaceLinearOperator( ValueError: if `scale` is unspecified. 
TypeError: if not `scale.dtype.is_floating` """ - parameters = locals() + parameters = dict(locals()) if scale is None: raise ValueError("Missing required `scale` parameter.") if not scale.dtype.is_floating: diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py index 1438ede26500bca4541fa9b2020ff22d4c071098..bb33cd0762a368eb7e53f1623ede9231e80f0b14 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -163,7 +163,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope( name, diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py index 7e78ded9df07564126b46b6beeeccf95bf1eef94..21f84dcbdea8b422dd45fadeac1bb8b2804c551f 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -175,7 +175,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. 
""" - parameters = locals() + parameters = dict(locals()) graph_parents = [df, loc, scale_identity_multiplier, scale_diag, scale_tril, scale_perturb_factor, scale_perturb_diag] with ops.name_scope(name) as name: diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index 91453fed5d279178a0e062b71dad3b0f957b11b4..88d4280759da7ca685056f4d41cf8dc51393c9f3 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -107,7 +107,7 @@ class _WishartLinearOperator(distribution.Distribution): ValueError: if df < k, where scale operator event shape is `(k, k)` """ - parameters = locals() + parameters = dict(locals()) self._cholesky_input_output_matrices = cholesky_input_output_matrices with ops.name_scope(name) as name: with ops.name_scope("init", values=[df, scale_operator]): @@ -530,7 +530,7 @@ class WishartCholesky(_WishartLinearOperator): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[scale]) as name: with ops.name_scope("init", values=[scale]): scale = ops.convert_to_tensor(scale) @@ -646,7 +646,7 @@ class WishartFull(_WishartLinearOperator): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. 
""" - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name) as name: with ops.name_scope("init", values=[scale]): scale = ops.convert_to_tensor(scale) diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 99abbae03fc14f241dae27f317902f7335819037..0cc764d2208c5b061b7b836bdf57a035f52c6fcf 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -120,7 +120,6 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/python:array_ops", - "//tensorflow/python:checkpointable", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", @@ -131,6 +130,7 @@ py_library( "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", "//tensorflow/python/eager:function", + "//tensorflow/python/training/checkpointable:base", ], ) diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 0783d1b5d70e502e6edd80b59f37fdd93b413e12..adf92c27ea0a27c5741bcdd175b277462cb28d02 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -31,7 +31,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.training import checkpointable +from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.training.saver import BaseSaverBuilder _uid_counter = 0 @@ -106,7 +106,8 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): target_device=target, buffer_size=10, container="", - shared_name=_generate_shared_name("function_buffer_resource")) + shared_name=_generate_shared_name( + "contrib_eager_iterator_function_buffer_resource")) self._buffer_resource_deleter = 
resource_variable_ops.EagerResourceDeleter( # pylint: disable=line-too-long handle=self._buffer_resource_handle, handle_device=self._device) diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index 7b123707cc3a26073088cf2c57c6211e831c19fd..68bec9aee894edd60a025ac1cf87ca3e010db842 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -37,7 +37,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops -from tensorflow.python.training import checkpointable_utils +from tensorflow.python.training.checkpointable import util as checkpointable_utils class IteratorTest(test.TestCase): diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index c1fd9e0ed020beeb722204edf1adfe1dfcf8ff03..1d9371c7ac405dbf0ec40210270b90f2cf9b9a25 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -7,6 +7,8 @@ py_library( name = "examples_pip", deps = [ "//tensorflow/contrib/eager/python/examples/gan:mnist", + "//tensorflow/contrib/eager/python/examples/l2hmc", + "//tensorflow/contrib/eager/python/examples/l2hmc:neural_nets", "//tensorflow/contrib/eager/python/examples/linear_regression", "//tensorflow/contrib/eager/python/examples/resnet50", "//tensorflow/contrib/eager/python/examples/rnn_colorbot", diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py index b80c90902353709b7f739585291ec3b5890c27c7..cc9cf53410f641cc3303b4450e9eaa1301904a64 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py @@ -227,7 +227,7 @@ def train_one_epoch(generator, discriminator, generator_optimizer, 
maxval=1., seed=batch_index) - with tfe.GradientTape(persistent=True) as g: + with tf.GradientTape(persistent=True) as g: generated_images = generator(noise) tf.contrib.summary.image( 'generated_images', @@ -306,7 +306,7 @@ def main(_): if __name__ == '__main__': - tfe.enable_eager_execution() + tf.enable_eager_execution() parser = argparse.ArgumentParser() parser.add_argument( diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py index bd35e50c1f434d167c5a8c5aa7d224912523ce28..81ac05e26d23c2fc53f63d64bb28bdea6072e396 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py @@ -111,5 +111,5 @@ class MnistEagerGanBenchmark(tf.test.Benchmark): if __name__ == '__main__': - tfe.enable_eager_execution() + tf.enable_eager_execution() tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/BUILD b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..7bdf9053de749af9d09b12ba7b848e21c1fdb8f0 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD @@ -0,0 +1,39 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +py_library( + name = "neural_nets", + srcs = ["neural_nets.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + ], +) + +py_library( + name = "l2hmc", + srcs = ["l2hmc.py"], + srcs_version = "PY2AND3", + deps = [ + ":neural_nets", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + "//third_party/py/numpy", + ], +) + +cuda_py_test( + name = "l2hmc_test", + size = "large", + srcs = ["l2hmc_test.py"], + additional_deps = [ + ":l2hmc", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + 
"//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py new file mode 100644 index 0000000000000000000000000000000000000000..729d8525fab31ee214178ca1bcb18dbd069f767a --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py @@ -0,0 +1,326 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""L2HMC compatible with TensorFlow's eager execution. + +Reference [Generalizing Hamiltonian Monte Carlo with Neural +Networks](https://arxiv.org/pdf/1711.09268.pdf) + +Code adapted from the released TensorFlow graph implementation by original +authors https://github.com/brain-research/l2hmc. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import numpy.random as npr +import tensorflow as tf +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.l2hmc import neural_nets + + +class Dynamics(tf.keras.Model): + """Dynamics engine of naive L2HMC sampler. 
+ + Args: + x_dim: dimensionality of observed data + loglikelihood_fn: log-likelihood function of conditional probability + n_steps: number of leapfrog steps within each transition + eps: initial value learnable scale of step size + """ + + def __init__(self, x_dim, loglikelihood_fn, n_steps=25, eps=.1): + super(Dynamics, self).__init__() + + self.x_dim = x_dim + self.potential = loglikelihood_fn + self.n_steps = n_steps + + self._construct_time() + self._construct_masks() + + self.position_fn = neural_nets.GenericNet(x_dim, factor=2.) + self.momentum_fn = neural_nets.GenericNet(x_dim, factor=1.) + + self.eps = tfe.Variable( + initial_value=eps, name="eps", dtype=tf.float32, trainable=True) + + def apply_transition(self, position): + """Propose a new state and perform the accept or reject step.""" + + # Simulate dynamics both forward and backward; + # Use sampled Bernoulli masks to compute the actual solutions + position_f, momentum_f, accept_prob_f = self.transition_kernel( + position, forward=True) + position_b, momentum_b, accept_prob_b = self.transition_kernel( + position, forward=False) + + # Decide direction uniformly + forward_mask = tf.cast( + tf.random_uniform(shape=[tf.shape(position)[0]]) > .5, tf.float32) + backward_mask = 1. - forward_mask + + # Obtain proposed states + position_post = ( + forward_mask[:, None] * position_f + + backward_mask[:, None] * position_b) + momentum_post = ( + forward_mask[:, None] * momentum_f + + backward_mask[:, None] * momentum_b) + + # Probability of accepting the proposed states + accept_prob = forward_mask * accept_prob_f + backward_mask * accept_prob_b + + # Accept or reject step + accept_mask = tf.cast( + accept_prob > tf.random_uniform(tf.shape(accept_prob)), tf.float32) + reject_mask = 1. 
- accept_mask + + # Samples after accept/reject step + position_out = ( + accept_mask[:, None] * position_post + reject_mask[:, None] * position) + + return position_post, momentum_post, accept_prob, position_out + + def transition_kernel(self, position, forward=True): + """Transition kernel of augmented leapfrog integrator.""" + + lf_fn = self._forward_lf if forward else self._backward_lf + + # Resample momentum + momentum = tf.random_normal(tf.shape(position)) + position_post, momentum_post = position, momentum + sumlogdet = 0. + # Apply augmented leapfrog steps + for i in range(self.n_steps): + position_post, momentum_post, logdet = lf_fn(position_post, momentum_post, + i) + sumlogdet += logdet + + accept_prob = self._compute_accept_prob(position, momentum, position_post, + momentum_post, sumlogdet) + + return position_post, momentum_post, accept_prob + + def _forward_lf(self, position, momentum, i): + """One forward augmented leapfrog step. See eq (5-6) in paper.""" + + t = self._get_time(i) + mask, mask_inv = self._get_mask(i) + sumlogdet = 0. + + momentum, logdet = self._update_momentum_forward(position, momentum, t) + sumlogdet += logdet + + position, logdet = self._update_position_forward(position, momentum, t, + mask) + sumlogdet += logdet + + position, logdet = self._update_position_forward(position, momentum, t, + mask_inv) + sumlogdet += logdet + + momentum, logdet = self._update_momentum_forward(position, momentum, t) + sumlogdet += logdet + + return position, momentum, tf.reduce_sum(sumlogdet, axis=1) + + def _backward_lf(self, position, momentum, i): + """One backward augmented leapfrog step. See Appendix A in paper.""" + + # Reversed index/sinusoidal time + t = self._get_time(self.n_steps - i - 1) + mask, mask_inv = self._get_mask(self.n_steps - i - 1) + sumlogdet = 0. 
+ + momentum, logdet = self._update_momentum_backward(position, momentum, t) + sumlogdet += logdet + + position, logdet = self._update_position_backward(position, momentum, t, + mask) + sumlogdet += logdet + + position, logdet = self._update_position_backward(position, momentum, t, + mask_inv) + sumlogdet += logdet + + momentum, logdet = self._update_momentum_backward(position, momentum, t) + sumlogdet += logdet + + return position, momentum, tf.reduce_sum(sumlogdet, axis=1) + + def _update_momentum_forward(self, position, momentum, t): + """Update v in the forward leapfrog step.""" + + grad = self.grad_potential(position) + scale, translation, transformed = self.momentum_fn([position, grad, t]) + scale *= .5 * self.eps + transformed *= self.eps + momentum = ( + momentum * tf.exp(scale) - + .5 * self.eps * (tf.exp(transformed) * grad - translation)) + + return momentum, scale + + def _update_position_forward(self, position, momentum, t, mask): + """Update x in the forward leapfrog step.""" + + mask_inv = 1. - mask + scale, translation, transformed = self.position_fn( + [momentum, mask * position, t]) + scale *= self.eps + transformed *= self.eps + position = ( + mask * position + + mask_inv * (position * tf.exp(scale) + self.eps * + (tf.exp(transformed) * momentum + translation))) + + return position, mask_inv * scale + + def _update_momentum_backward(self, position, momentum, t): + """Update v in the backward leapfrog step. Inverting the forward update.""" + + grad = self.grad_potential(position) + scale, translation, transformed = self.momentum_fn([position, grad, t]) + scale *= -.5 * self.eps + transformed *= self.eps + momentum = ( + tf.exp(scale) * (momentum + .5 * self.eps * + (tf.exp(transformed) * grad - translation))) + + return momentum, scale + + def _update_position_backward(self, position, momentum, t, mask): + """Update x in the backward leapfrog step. Inverting the forward update.""" + + mask_inv = 1. 
- mask + scale, translation, transformed = self.position_fn( + [momentum, mask_inv * position, t]) + scale *= -self.eps + transformed *= self.eps + position = ( + mask_inv * position + mask * tf.exp(scale) * + (position - self.eps * tf.exp(transformed) * momentum + translation)) + + return position, mask * scale + + def _compute_accept_prob(self, position, momentum, position_post, + momentum_post, sumlogdet): + """Compute the prob of accepting the proposed state given old state.""" + + old_hamil = self.hamiltonian(position, momentum) + new_hamil = self.hamiltonian(position_post, momentum_post) + + return tf.exp(tf.minimum(old_hamil - new_hamil + sumlogdet, 0.)) + + def _construct_time(self): + """Convert leapfrog step index into sinusoidal time.""" + + self.ts = [] + for i in range(self.n_steps): + t = tf.constant( + [ + np.cos(2 * np.pi * i / self.n_steps), + np.sin(2 * np.pi * i / self.n_steps) + ], + dtype=tf.float32) + self.ts.append(t[None, :]) + + def _get_time(self, i): + """Get sinusoidal time for i-th augmented leapfrog step.""" + + return self.ts[i] + + def _construct_masks(self): + """Construct different binary masks for different time steps.""" + + self.masks = [] + for _ in range(self.n_steps): + idx = npr.permutation(np.arange(self.x_dim))[:self.x_dim // 2] + mask = np.zeros((self.x_dim,)) + mask[idx] = 1. + mask = tf.constant(mask, dtype=tf.float32) + self.masks.append(mask[None, :]) + + def _get_mask(self, i): + """Get binary masks for i-th augmented leapfrog step.""" + + m = self.masks[i] + return m, 1. 
- m + + def kinetic(self, v): + """Compute the kinetic energy.""" + + return .5 * tf.reduce_sum(v**2, axis=1) + + def hamiltonian(self, position, momentum): + """Compute the overall Hamiltonian.""" + + return self.potential(position) + self.kinetic(momentum) + + def grad_potential(self, position, check_numerics=True): + """Get gradient of potential function at current location.""" + + if not tf.executing_eagerly(): + # TODO(lxuechen): Change this to tfe.gradients_function when it works + grad = tf.gradients(self.potential(position), position)[0] + else: + grad = tfe.gradients_function(self.potential)(position)[0] + + if check_numerics: + return tf.check_numerics(grad, message="gradient of potential") + + return grad + + +# Examples of unnormalized log density/probabilities +def get_scg_energy_fn(): + """Get energy function for 2d strongly correlated Gaussian.""" + + # Avoid recreating tf constants on each invocation of gradients + mu = tf.constant([0., 0.]) + sigma = tf.constant([[50.05, -49.95], [-49.95, 50.05]]) + sigma_inv = tf.matrix_inverse(sigma) + + def energy(x): + """Unnormalized log density/energy of 2d strongly correlated Gaussian.""" + + xmmu = x - mu + return .5 * tf.diag_part( + tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu))) + + return energy + + +def get_multivariate_gaussian_energy_fn(x_dim=2): + """Get energy function for 2d strongly correlated Gaussian.""" + + mu = tf.random_normal(shape=[x_dim]) + # Lower triangularize and positive diagonal + l = tf.sigmoid( + tf.matrix_band_part(tf.random_normal(shape=[x_dim, x_dim]), -1, 0)) + # Exploit Cholesky decomposition + sigma = tf.matmul(l, tf.transpose(l)) + sigma *= 100. 
# Small covariance causes extreme numerical instability + sigma_inv = tf.matrix_inverse(sigma) + + def energy(x): + """Unnormalized log density/energy of a multivariate Gaussian.""" + + xmmu = x - mu + return .5 * tf.diag_part( + tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu))) + + return energy diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e33b4cae4c73388dfd78542c9907953f137ad710 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py @@ -0,0 +1,264 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests l2hmc fit to 2D strongly correlated Gaussian executed eagerly.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy.random as npr +import tensorflow as tf +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.l2hmc import l2hmc + + +def get_default_hparams(): + return tf.contrib.training.HParams( + x_dim=2, + n_samples=200, + n_steps=10, + eps=.1, + n_iters=10, + learning_rate=.0003, + n_warmup_iters=3) + + +# Relevant functions for benchmarking +def compute_loss(dynamics, x, scale=.1, eps=1e-4): + """Compute loss defined in equation (8).""" + + z = tf.random_normal(tf.shape(x)) + x_, _, x_accept_prob, x_out = dynamics.apply_transition(x) + z_, _, z_accept_prob, _ = dynamics.apply_transition(z) + + # Add eps for numerical stability; following released impl + x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps + z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps + + loss = tf.reduce_mean( + (1. / x_loss + 1. 
/ z_loss) * scale - (x_loss + z_loss) / scale, axis=0) + + return loss, x_out + + +def loss_and_grads(dynamics, x, loss_fn=compute_loss): + """Obtain loss value and gradients.""" + + with tf.GradientTape() as tape: + loss_val, x_out = loss_fn(dynamics, x) + grads = tape.gradient(loss_val, dynamics.variables) + + return loss_val, grads, x_out + + +def warmup(dynamics, optimizer, n_iters=1, n_samples=200, loss_fn=compute_loss): + """Warmup optimization to reduce overhead.""" + + samples = tf.random_normal( + shape=[n_samples, dynamics.x_dim], dtype=tf.float32) + + for _ in range(n_iters): + _, grads, samples = loss_and_grads(dynamics, samples, loss_fn=loss_fn) + optimizer.apply_gradients(zip(grads, dynamics.variables)) + + +def fit(dynamics, + samples, + optimizer, + loss_fn=compute_loss, + n_iters=5000, + verbose=True, + logdir=None, + decay_lr=True): + """Fit L2HMC sampler with given log-likelihood function.""" + + if logdir: + summary_writer = tf.contrib.summary.create_file_writer(logdir) + + for i in range(n_iters): + loss, grads, samples = loss_and_grads(dynamics, samples, loss_fn=loss_fn) + # TODO(lxuechen): Proper learning rate decay + if decay_lr: + grads = [grad * .96**(i // 1000) for grad in grads] + optimizer.apply_gradients(zip(grads, dynamics.variables)) + if verbose: + print("Iteration %d: loss %.4f" % (i, loss)) + + if logdir: + with summary_writer.as_default(): + with tf.contrib.summary.always_record_summaries(): + tf.contrib.summary.scalar("loss", loss) + + +class L2hmcTest(tf.test.TestCase): + """Unit tests for l2hmc in both eager and graph mode.""" + + def test_apply_transition(self): + """Testing function `Dynamics.apply_transition` in graph and eager mode.""" + + # Eager mode testing + hparams = get_default_hparams() + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=l2hmc.get_scg_energy_fn(), + n_steps=hparams.n_steps, + eps=hparams.eps) + samples = tf.random_normal(shape=[hparams.n_samples, hparams.x_dim]) + x_, v_, 
x_accept_prob, x_out = dynamics.apply_transition(samples) + + self.assertEqual(x_.shape, v_.shape) + self.assertEqual(x_out.shape, samples.shape) + self.assertEqual(x_.shape, x_out.shape) + self.assertEqual(x_accept_prob.shape, (hparams.n_samples,)) + + # Graph mode testing + with tf.Graph().as_default(): + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=l2hmc.get_scg_energy_fn(), + n_steps=hparams.n_steps, + eps=hparams.eps) + x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim]) + x_, v_, x_accept_prob, x_out = dynamics.apply_transition(x) + samples = npr.normal(size=[hparams.n_samples, hparams.x_dim]) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + np_x_, np_v_, np_x_accept_prob, np_x_out = sess.run( + [x_, v_, x_accept_prob, x_out], feed_dict={x: samples}) + + self.assertEqual(np_x_.shape, np_v_.shape) + self.assertEqual(samples.shape, np_x_out.shape) + self.assertEqual(np_x_.shape, np_x_out.shape) + self.assertEqual(np_x_accept_prob.shape, (hparams.n_samples,)) + + +class L2hmcBenchmark(tf.test.Benchmark): + """Eager and graph benchmarks for l2hmc.""" + + def _get_energy_fn(self): + """Get specific energy function according to FLAGS.""" + + if FLAGS.energy_fn == "scg": + energy_fn = l2hmc.get_scg_energy_fn() + elif FLAGS.energy_fn == "multivariate_gaussian": + energy_fn = l2hmc.get_multivariate_gaussian_energy_fn(x_dim=FLAGS.x_dim) + else: + raise ValueError("No such energy function %s" % FLAGS.energy_fn) + + return energy_fn + + def benchmark_graph(self): + """Benchmark Graph performance.""" + + hparams = get_default_hparams() + tf.reset_default_graph() + with tf.Graph().as_default(): + energy_fn = self._get_energy_fn() + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=energy_fn, + n_steps=hparams.n_steps, + eps=hparams.eps) + x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim]) + loss, x_out = compute_loss(dynamics, x) + + global_step = tf.Variable(0., name="global_step", 
trainable=False) + learning_rate = tf.train.exponential_decay( + hparams.learning_rate, global_step, 1000, 0.96, staircase=True) + optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) + train_op = optimizer.minimize(loss, global_step=global_step) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + + # Warmup to reduce initialization effect when timing + samples = npr.normal(size=[hparams.n_samples, hparams.x_dim]) + for _ in range(hparams.n_warmup_iters): + _, _, _, _ = sess.run( + [x_out, loss, train_op, learning_rate], feed_dict={x: samples}) + + # Training + start_time = time.time() + for i in range(hparams.n_iters): + samples, loss_np, _, _ = sess.run( + [x_out, loss, train_op, learning_rate], feed_dict={x: samples}) + print("Iteration %d: loss %.4f" % (i, loss_np)) + wall_time = time.time() - start_time + examples_per_sec = hparams.n_samples / wall_time + + self.report_benchmark( + name="graph_train_%s" % ("gpu" + if tf.test.is_gpu_available() else "cpu"), + iters=hparams.n_iters, + extras={"examples_per_sec": examples_per_sec}, + wall_time=wall_time) + + def benchmark_eager(self): + self._benchmark_eager() + + def benchmark_eager_defun(self): + self._benchmark_eager(defun=True) + + def _benchmark_eager(self, defun=False): + """Benchmark Eager performance.""" + + hparams = get_default_hparams() + energy_fn = self._get_energy_fn() + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=energy_fn, + n_steps=hparams.n_steps, + eps=hparams.eps) + optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate) + loss_fn = tfe.defun(compute_loss) if defun else compute_loss + + # Warmup to reduce initialization effect when timing + warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters, loss_fn=loss_fn) + + # Training + samples = tf.random_normal( + shape=[hparams.n_samples, hparams.x_dim], dtype=tf.float32) + start_time = time.time() + fit(dynamics, + samples, + optimizer, + loss_fn=loss_fn, + 
n_iters=hparams.n_iters, + decay_lr=True) + wall_time = time.time() - start_time + examples_per_sec = hparams.n_samples / wall_time + + self.report_benchmark( + name="eager_train_%s%s" % ("gpu" if tf.test.is_gpu_available() else + "cpu", "_defun" if defun else ""), + iters=hparams.n_iters, + extras={"examples_per_sec": examples_per_sec}, + wall_time=wall_time) + + del dynamics + del loss_fn + + +if __name__ == "__main__": + tf.flags.DEFINE_string("energy_fn", "scg", + ("The energy function/unnormalized log-probability. " + "Must be either `scg` or `multivariate_gaussian`.")) + tf.flags.DEFINE_integer("x_dim", 2, "Dimensionality of observation space.") + FLAGS = tf.flags.FLAGS + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py new file mode 100644 index 0000000000000000000000000000000000000000..e230ad5e259df5b450897bd815e901e3934cd293 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py @@ -0,0 +1,84 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Neural nets utility for L2HMC compatible with TensorFlow's eager execution. 
+ +Reference [Generalizing Hamiltonian Monte Carlo with Neural +Networks](https://arxiv.org/pdf/1711.09268.pdf) + +Code adapted from the released TensorFlow graph implementation by original +authors https://github.com/brain-research/l2hmc. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import tensorflow.contrib.eager as tfe + + +class GenericNet(tf.keras.Model): + """Generic neural net with different initialization scale based on input. + + Args: + x_dim: dimensionality of observed data + factor: factor of variance scaling initializer + n_hidden: number of hidden units + """ + + def __init__(self, x_dim, factor, n_hidden=10): + super(GenericNet, self).__init__() + + self.v_layer = _custom_dense(n_hidden, 1. / 3.) + self.x_layer = _custom_dense(n_hidden, factor / 3.) + self.t_layer = _custom_dense(n_hidden, 1. / 3.) + self.h_layer = _custom_dense(n_hidden) + + # Scale + self.scale_layer = _custom_dense(x_dim, .001) + self.coeff_scale = tfe.Variable( + initial_value=tf.zeros([1, x_dim]), name='coeff_scale', trainable=True) + # Translation + self.translation_layer = _custom_dense(x_dim, factor=.001) + # Transformation + self.transformation_layer = _custom_dense(x_dim, .001) + self.coeff_transformation = tfe.Variable( + initial_value=tf.zeros([1, x_dim]), + name='coeff_transformation', + trainable=True) + + def call(self, inputs): + v, x, t = inputs + h = self.v_layer(v) + self.x_layer(x) + self.t_layer(t) + h = tf.nn.relu(h) + h = self.h_layer(h) + h = tf.nn.relu(h) + scale = tf.nn.tanh(self.scale_layer(h)) * tf.exp(self.coeff_scale) + translation = self.translation_layer(h) + transformation = ( + tf.nn.tanh(self.transformation_layer(h)) * tf.exp( + self.coeff_transformation)) + + return scale, translation, transformation + + +def _custom_dense(units, factor=1.): + """Custom dense layer with specified weight initialization.""" + + return tf.keras.layers.Dense( + 
units=units, + use_bias=True, + kernel_initializer=tf.contrib.layers.variance_scaling_initializer( + factor=factor * 2., mode='FAN_IN', uniform=False), + bias_initializer=tf.constant_initializer(0., dtype=tf.float32)) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index 4e1380afb2e6e722de65c691d4fbf44621072e87..099b712fc06d1d3eb9ab4095f8db7283690bda76 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -75,7 +75,6 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None): mse = lambda xs, ys: mean_square_loss(model, xs, ys) loss_and_grads = tfe.implicit_value_and_gradients(mse) - tf.train.get_or_create_global_step() if logdir: # Support for TensorBoard summaries. Once training has started, use: # tensorboard --logdir= @@ -87,12 +86,13 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None): if verbose: print("Iteration %d: loss = %s" % (i, loss.numpy())) - optimizer.apply_gradients(grads, global_step=tf.train.get_global_step()) + optimizer.apply_gradients(grads) if logdir: with summary_writer.as_default(): with tf.contrib.summary.always_record_summaries(): - tf.contrib.summary.scalar("loss", loss) + tf.contrib.summary.scalar("loss", loss, step=i) + tf.contrib.summary.scalar("step", i, step=i) def synthetic_dataset(w, b, noise_level, batch_size, num_batches): @@ -119,7 +119,7 @@ def synthetic_dataset_helper(w, b, num_features, noise_level, batch_size, def main(_): - tfe.enable_eager_execution() + tf.enable_eager_execution() # Ground-truth constants. 
true_w = [[-2.0], [4.0], [1.0]] true_b = [0.5] diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py index e53234b51a7dccc11e548ac81a7ef070c628aa52..2bc2fc2aa9150a3181db612439d0c37c8e76d1e3 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py @@ -117,5 +117,5 @@ class EagerLinearRegressionBenchmark(tf.test.Benchmark): if __name__ == "__main__": - tfe.enable_eager_execution() + tf.enable_eager_execution() tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb index 9fd2d8d1254e32ae75ab5b085986c6e1c05e76f4..51d10a778413cfbb574b4e22e8adcb18bd731dee 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb @@ -1,495 +1,429 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Eager Execution Tutorial: Basics", - "version": "0.3.2", - "views": {}, - "default_view": {}, - "provenance": [ - { - "file_id": "0B0kLcpwLFwKEVm9XNkFueGk4bTg", - "timestamp": 1504118841551 - } - ] - } - }, "cells": [ { + "cell_type": "markdown", "metadata": { - "id": "U9i2Dsh-ziXr", - "colab_type": "text" + "colab_type": "text", + "id": "U9i2Dsh-ziXr" }, - "cell_type": "markdown", "source": [ - "# Eager Execution Tutorial: Basics\n", + "# An introduction to TensorFlow\n", "\n", - "This notebook introduces the basics of using TensorFlow's eager execution capabilities. It covers concepts such as:\n", + "This is an introductory tutorial for using TensorFlow. 
It will cover:\n", "\n", "* Importing required packages\n", - "* Enabling eager execution\n", - "* Creating and using TensorFlow Tensors and Variables\n", - "* Using TensorFlow interactively\n", - "* Using GPUs with eager execution enabled\n", - "\n", - "This notebook does *not* cover modeling topics, such as gradients." + "* Creating and using Tensors\n", + "* Using GPU acceleration\n" ] }, { + "cell_type": "markdown", "metadata": { - "id": "z1JcS5iBXMRO", - "colab_type": "text" + "colab_type": "text", + "id": "z1JcS5iBXMRO" }, - "cell_type": "markdown", "source": [ - "# Step 1: Import Eager\n", + "## Import TensorFlow\n", "\n", - "The key imports for eager execution are the following:" + "To get started, import the `tensorflow` module and enable eager execution.\n", + "Eager execution enables a more interactive frontend to TensorFlow, the details of which we will discuss much later." ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "RlIWhyeLoYnG", - "colab_type": "code", + "cellView": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, - "cellView": "code" + "colab_type": "code", + "id": "RlIWhyeLoYnG" }, - "cell_type": "code", + "outputs": [], "source": [ - "# Import TensorFlow.\n", "import tensorflow as tf\n", "\n", - "# Import TensorFlow eager execution support (subject to future changes).\n", - "tfe = tf.contrib.eager" - ], - "execution_count": 0, - "outputs": [] + "tf.enable_eager_execution()" + ] }, { + "cell_type": "markdown", "metadata": { - "id": "H9UySOPLXdaw", - "colab_type": "text" + "colab_type": "text", + "id": "H9UySOPLXdaw" }, - "cell_type": "markdown", "source": [ - "# Step 2: Enable eager execution\n", + "## Tensors\n", "\n", - "All future TensorFlow calls will execute the\n", - "underlying TensorFlow ops immediately:" + "A Tensor is a multi-dimensional array. Similar to NumPy `ndarray` objects, `Tensor` objects have a data type and a shape. 
Additionally, Tensors can reside in accelerator (like GPU) memory. TensorFlow offers a rich library of operations ([tf.add](https://www.tensorflow.org/api_docs/python/tf/add), [tf.matmul](https://www.tensorflow.org/api_docs/python/tf/matmul), [tf.linalg.inv](https://www.tensorflow.org/api_docs/python/tf/linalg/inv) etc.) that consume and produce Tensors. These operations automatically convert native Python types. For example:\n" ] }, { - "metadata": { - "id": "WPTUfGq6kJ5w", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "code" - }, "cell_type": "code", - "source": [ - "tf.enable_eager_execution()" - ], "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "twBfWd5xyu_d", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Step 3: Interactively Use TensorFlow!\n", - "\n", - "Now you can call TensorFlow functions and get results, immediately! No more `tf.Sessions`!\n", - "\n", - "TensorFlow will automatically wrap native Python types for you with operator overloading for TensorFlow Tensors." 
- ] - }, - { "metadata": { - "id": "ngUe237Wt48W", - "colab_type": "code", + "cellView": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - } + }, + "height": 125 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 320, + "status": "ok", + "timestamp": 1526420535530, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 }, - "cellView": "code" + "id": "ngUe237Wt48W", + "outputId": "b1a1cd60-4eb3-443d-cd6b-68406390784e" }, - "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tf.Tensor(3, shape=(), dtype=int32)\n", + "tf.Tensor([4 6], shape=(2,), dtype=int32)\n", + "tf.Tensor(25, shape=(), dtype=int32)\n", + "tf.Tensor(6, shape=(), dtype=int32)\n", + "tf.Tensor(aGVsbG8gd29ybGQ, shape=(), dtype=string)\n", + "tf.Tensor(13, shape=(), dtype=int32)\n" + ] + } + ], "source": [ "print(tf.add(1, 2))\n", "print(tf.add([1, 2], [3, 4]))\n", "print(tf.square(5))\n", "print(tf.reduce_sum([1, 2, 3]))\n", "print(tf.encode_base64(\"hello world\"))\n", - "print(\"\")\n", "\n", - "x = tf.constant(2)\n", - "y = tf.constant(3)\n", - "print(x * y + 1)\n", - "\n", - "# Most TensorFlow ops are directly usable with eager execution, giving\n", - "# results immediately.\n", - "print(tf.contrib.signal.hamming_window(x * y + 1))" - ], - "execution_count": 0, - "outputs": [] + "# Operator overloading is also supported\n", + "print(tf.square(2) + tf.square(3))" + ] }, { + "cell_type": "markdown", "metadata": { - "id": "IDY4WsYRhP81", - "colab_type": "text" + "colab_type": "text", + "id": "IDY4WsYRhP81" }, - "cell_type": "markdown", "source": [ - "Numpy arrays are supported, too:" + "Each Tensor has a shape and a datatype" ] }, { - "metadata": { - "id": "lCUWzso6mbqR", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, "cell_type": "code", - "source": [ - "import numpy as np\n", - "\n", - "ones = np.ones([3, 3])\n", - "\n", - 
"print(\"numpy 3x3 matrix of 1s:\")\n", - "print(ones)\n", - "print(\"\")\n", - "\n", - "print(\"Multiplied by 42:\")\n", - "print(tf.multiply(ones, 42))" - ], "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "PBNP8yTRfu_X", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Step 4: Define and Print TensorFlow Variables\n", - "\n", - "To define TensorFlow variables, use the `get_variable()` function as follows:" - ] - }, - { "metadata": { - "id": "3Twf_Rw-gQFM", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - } + }, + "height": 53 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 215, + "status": "ok", + "timestamp": 1526420538162, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 }, - "cellView": "code" + "id": "srYWH1MdJNG7", + "outputId": "5e4ac41c-5115-4e50-eba0-42e249c16561" }, - "cell_type": "code", - "source": [ - "x = tfe.Variable(0.)" + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1, 2)\n", + "\u003cdtype: 'int32'\u003e\n" + ] + } ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "45G7094TxsMb", - "colab_type": "text" - }, - "cell_type": "markdown", "source": [ - "## Printing TensorFlow Variables" + "x = tf.matmul([[1]], [[2, 3]])\n", + "print(x.shape)\n", + "print(x.dtype)" ] }, { + "cell_type": "markdown", "metadata": { - "id": "UJBJeZ5XxuwA", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "code" + "colab_type": "text", + "id": "eBPw8e8vrsom" }, - "cell_type": "code", "source": [ - "# This does NOT print the Variable's actual value:\n", - "print(\"Printing a TensorFlow Variable:\")\n", - "print(x)\n", - "print(\"\")\n", + "The most obvious differences between NumPy arrays and TensorFlow Tensors are:\n", "\n", - "\n", - "print(\"Printing a TensorFlow Variable's value as a numpy array:\")\n", - 
"print(x.numpy())" - ], - "execution_count": 0, - "outputs": [] + "1. Tensors can be backed by accelerator memory (like GPU, TPU).\n", + "2. Tensors are immutable." + ] }, { + "cell_type": "markdown", "metadata": { - "id": "2njjWHcTpBEn", - "colab_type": "text" + "colab_type": "text", + "id": "Dwi1tdW3JBw6" }, - "cell_type": "markdown", "source": [ - "## Changing a TensorFlow Variable's value\n", + "### NumPy Compatibility\n", "\n", - "To change a TensorFlow Variable's value, use its `.assign()` or `.assign_add()` method:" + "Conversion between TensorFlow Tensors and NumPy ndarrays is quite simple as:\n", + "* TensorFlow operations automatically convert NumPy ndarrays to Tensors.\n", + "* NumPy operations automatically convert Tensors to NumPy ndarrays.\n", + "\n", + "Tensors can be explicitly converted to NumPy ndarrays by invoking the `.numpy()` method on them.\n", + "These conversions are typically cheap as the array and Tensor share the underlying memory representation if possible. However, sharing the underlying representation isn't always possible since the Tensor may be hosted in GPU memory while NumPy arrays are always backed by host memory, and the conversion will thus involve a copy from GPU to host memory." 
] }, { - "metadata": { - "id": "v3wr6Erbo_hB", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, "cell_type": "code", - "source": [ - "x.assign(42)\n", - "print(x)\n", - "\n", - "x.assign_add(3)\n", - "print(x)" - ], "execution_count": 0, - "outputs": [] - }, - { "metadata": { - "id": "uhtynjHVpTB5", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Use a Variable just like any other Tensor" - ] - }, - { - "metadata": { - "id": "7PbktdnHoehR", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - } - } + }, + "height": 251 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 238, + "status": "ok", + "timestamp": 1526420540562, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "lCUWzso6mbqR", + "outputId": "fd0a22bc-8249-49dd-fcbd-63161cc47e46" }, - "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TensorFlow operations convert numpy arrays to Tensors automatically\n", + "tf.Tensor(\n", + "[[ 42. 42. 42.]\n", + " [ 42. 42. 42.]\n", + " [ 42. 42. 42.]], shape=(3, 3), dtype=float64)\n", + "And NumPy operations convert Tensors to numpy arrays automatically\n", + "[[ 43. 43. 43.]\n", + " [ 43. 43. 43.]\n", + " [ 43. 43. 43.]]\n", + "The .numpy() method explicitly converts a Tensor to a numpy array\n", + "[[ 42. 42. 42.]\n", + " [ 42. 42. 42.]\n", + " [ 42. 42. 
42.]]\n" + ] + } + ], "source": [ - "print(x + 3)\n", + "import numpy as np\n", "\n", - "# This code will broadcast the value across the list of numbers:\n", - "print(x * [1, 2, 4])" - ], - "execution_count": 0, - "outputs": [] + "ndarray = np.ones([3, 3])\n", + "\n", + "print(\"TensorFlow operations convert numpy arrays to Tensors automatically\")\n", + "tensor = tf.multiply(ndarray, 42)\n", + "print(tensor)\n", + "\n", + "\n", + "print(\"And NumPy operations convert Tensors to numpy arrays automatically\")\n", + "print(np.add(tensor, 1))\n", + "\n", + "print(\"The .numpy() method explicitly converts a Tensor to a numpy array\")\n", + "print(tensor.numpy())" + ] }, { + "cell_type": "markdown", "metadata": { - "id": "GVChqwlwy1SI", - "colab_type": "text" + "colab_type": "text", + "id": "PBNP8yTRfu_X" }, - "cell_type": "markdown", "source": [ - "# Step 5: Debug Errors with Instant Feedback\n", + "## GPU acceleration\n", "\n", - "TensorFlow's eager execution helps you identify and debug runtime issues through interactive exploration of code snippets.\n", - "\n", - "Below, we'll define a length-4 vector, and attempt two `tf.slice()` operations,\n", - "one being legal and the other being illegal, leading to a runtime error that is\n", - "raised immediately." + "Many TensorFlow operations can be accelerated by using the GPU for computation. Without any annotations, TensorFlow automatically decides whether to use the GPU or CPU for an operation (and copies the tensor between CPU and GPU memory if necessary). Tensors produced by an operation are typically backed by the memory of the device on which the operation executed. 
For example:" ] }, { - "metadata": { - "id": "23ap04N0v4k0", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "code" - }, "cell_type": "code", - "source": [ - "vector = tf.constant([10.0, 20.0, 30.0, 40.0])" - ], "execution_count": 0, - "outputs": [] - }, - { "metadata": { - "id": "FCUMsIYxxRRa", - "colab_type": "code", + "cellView": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - } + }, + "height": 53 }, - "cellView": "code" - }, - "cell_type": "code", - "source": [ - "# Works, because the values of `begin` and `size` (the 2nd and 3rd input\n", - "# arguments) are within the bound of `vector`.\n", - "print(tf.slice(vector, [1], [3]))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "T8me2oCNxpFp", "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } + "executionInfo": { + "elapsed": 340, + "status": "ok", + "timestamp": 1526420543562, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 }, - "cellView": "code" + "id": "3Twf_Rw-gQFM", + "outputId": "2239ae2b-adf3-4895-b1f3-464cf5361d1b" }, - "cell_type": "code", - "source": [ - "# The following does NOT work, because the value of `size` (the 3rd\n", - "# argument) causes the indices to go out of the bounds of `vector`. 
The\n", - "# error is raised immediately.\n", - "try:\n", - " print(tf.slice(vector, [1], [4]))\n", - "except tf.OpError as e:\n", - " print(\"Caught error: %s\" % e)" + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Is there a GPU available: False\n", + "Is the Tensor on GPU #0: False\n" + ] + } ], - "execution_count": 0, - "outputs": [] + "source": [ + "x = tf.random_uniform([3, 3])\n", + "\n", + "print(\"Is there a GPU available: \"),\n", + "print(tf.test.is_gpu_available())\n", + "\n", + "print(\"Is the Tensor on GPU #0: \"),\n", + "print(x.device.endswith('GPU:0'))" + ] }, { + "cell_type": "markdown", "metadata": { - "id": "irxJhAgar84v", - "colab_type": "text" + "colab_type": "text", + "id": "vpgYzgVXW2Ud" }, - "cell_type": "markdown", "source": [ - "# Step 6: Using the GPU\n", - "\n", - "You can explicitly place Tensors on the GPU by calling a Tensor's `.gpu()` method. The `.device` property tells you whether the Tensor is backed by CPU or GPU memory.\n", + "### Device Names\n", "\n", - "The first operation executing on the GPU may be slow as TensorFlow initializes. Subsequent uses will be much faster." + "The `Tensor.device` property provides a fully qualified string name of the device hosting the contents of the Tensor. This name encodes a bunch of details, such as an identifier of the network address of the host on which this program is executing and the device within that host. This is required for distributed execution of TensorFlow programs, but we'll skip that for now. The string will end with `GPU:\u003cN\u003e` if the tensor is placed on the `N`-th GPU on the host."
] }, { + "cell_type": "markdown", "metadata": { - "id": "7J4N9baqaKCL", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "colab_type": "text", + "id": "ZWZQCimzuqyP" }, - "cell_type": "code", "source": [ - "# Create some Tensors\n", - "SIZE = 1000\n", - "tensor = tf.random_normal([SIZE, SIZE])\n", - "print(tensor.device)\n", "\n", "\n", - "if tf.test.is_gpu_available():\n", - " gpu_tensor = tensor.gpu()\n", - " cpu_tensor = tensor.cpu()\n", - "else:\n", - " print(\"GPU not available.\")\n", - " cpu_tensor = tensor" - ], - "execution_count": 0, - "outputs": [] + "### Explicit Device Placement\n", + "\n", + "The term \"placement\" in TensorFlow refers to how individual operations are assigned to (placed on) a device for execution. As mentioned above, when there is no explicit guidance provided, TensorFlow automatically decides which device to execute an operation on, and copies Tensors to that device if needed. However, TensorFlow operations can be explicitly placed on specific devices using the `tf.device` context manager. 
For example:" + ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "4E-2n7VbzY1n", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - } - } + }, + "height": 53 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 1762, + "status": "ok", + "timestamp": 1526420547562, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "RjkNZTuauy-Q", + "outputId": "2e613293-ccac-4db2-b793-8ceb5b5adcfd" }, - "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "On CPU:\n", + "10 loops, best of 3: 35.8 ms per loop\n" + ] + } + ], "source": [ - "# Time a CPU-based matrix multiplication\n", + "def time_matmul(x):\n", + " %timeit tf.matmul(x, x)\n", "\n", - "print(\"Time to conduct matmul on CPU:\")\n", - "%time tf.matmul(cpu_tensor, cpu_tensor)" - ], - "execution_count": 0, - "outputs": [] + "# Force execution on CPU\n", + "print(\"On CPU:\")\n", + "with tf.device(\"CPU:0\"):\n", + " x = tf.random_uniform([1000, 1000])\n", + " assert x.device.endswith(\"CPU:0\")\n", + " time_matmul(x)\n", + "\n", + "# Force execution on GPU #0 if available\n", + "if tf.test.is_gpu_available():\n", + " with tf.device(\"GPU:0\"): # Or GPU:1 for the 2nd GPU, GPU:2 for the 3rd etc.\n", + " x = tf.random_uniform([1000, 1000])\n", + " assert x.device.endswith(\"GPU:0\")\n", + " time_matmul(x)" + ] }, { + "cell_type": "markdown", "metadata": { - "id": "vbSFW-T5zhZF", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "colab_type": "text", + "id": "YEOJTNiOvnpQ" }, - "cell_type": "code", "source": [ - "# Time GPU-based matrix multiplications.\n", + "## Next Steps\n", "\n", - "if tf.test.is_gpu_available():\n", - " # First use of the GPU will be slow:\n", - " print(\"Time to conduct first matmul on GPU:\")\n", - " %time tf.matmul(gpu_tensor, gpu_tensor)\n", - " print()\n", - "\n", - " # 
Subsequent uses are much faster:\n", - " print(\"Time to conduct second matmul on GPU:\")\n", - " %time tf.matmul(gpu_tensor, gpu_tensor)" - ], - "execution_count": 0, - "outputs": [] + "In this tutorial we covered the most fundamental concepts in TensorFlow - `Tensor`s, operations, and devices.\n", + "In [the next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/2_gradients.ipynb) we will cover automatic differentiation - a building block required for training many machine learning models like neural networks." + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "TensorFlow: An introduction", + "provenance": [], + "version": "0.3.2", + "views": {} } - ] -} \ No newline at end of file + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb index 1e65b27bc8be8b05fefa38dffae7799b1e503bd3..9c1af9c2084bac7ae6369babeaa13720e6199097 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb @@ -7,12 +7,9 @@ "id": "vDJ4XzMqodTy" }, "source": [ - "# Eager Execution: Working with Gradients\n", + "# Automatic Differentiation\n", "\n", - "This notebook demonstrates:\n", - "\n", - "* How to get gradients using TensorFlow's eager execution capabilities\n", - "* How to apply the gradients so you can update your variables" + "In the previous tutorial we introduced `Tensor`s and operations on them. In this tutorial we will cover [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation), a key technique for optimizing machine learning models." 
] }, { @@ -22,7 +19,7 @@ "id": "GQJysDM__Qb0" }, "source": [ - "# Setup: Import eager and enable eager execution.\n" + "## Setup\n" ] }, { @@ -40,12 +37,10 @@ }, "outputs": [], "source": [ - "# Import TensorFlow.\n", "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", "\n", - "\n", - "# Enable eager execution.\n", - "tf.enable_eager_execution()" + "tfe = tf.contrib.eager # Shorthand for some symbols" ] }, { @@ -55,28 +50,15 @@ "id": "1CLWJl0QliB0" }, "source": [ - "# Fitting a Simple Linear Model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "-39gouo7mtgu" - }, - "source": [ - "## Step 1: Synthesize some data\n", - "\n", - "To demonstrate fitting a model with TensorFlow's eager execution, we'll fit a linear model to some synthesized data (which includes some noise).\n", + "## Derivatives of a function\n", "\n", - "In the code, we use the variable names `w` and `b` to represent the single weight and bias we'll use to fit our model." + "TensorFlow provides APIs for automatic differentiation - computing the derivative of a function. The way that more closely mimics the math is to encapsulate the computation in a Python function, say `f`, and use `tfe.gradients_function` to create a function that computes the derivatives of `f` with respect to its arguments. If you're familiar with [autograd](https://github.com/HIPS/autograd) for differentiating numpy functions, this will be familiar. 
For example: " ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, @@ -84,105 +66,53 @@ } }, "colab_type": "code", - "id": "rQsdCg9PfIL-" + "id": "9FViq92UX7P8" }, "outputs": [], "source": [ - "# The constants we'll try to fit our variables to:\n", - "true_w = 3\n", - "true_b = 2\n", - "\n", - "NUM_EXAMPLES = 1000\n", + "from math import pi\n", "\n", - "# Our inputs:\n", - "inputs = tf.random_normal(shape=[NUM_EXAMPLES, 1])\n", + "def f(x):\n", + " return tf.square(tf.sin(x))\n", "\n", - "# Our labels, with noise:\n", - "noise = tf.random_normal(shape=[NUM_EXAMPLES, 1])\n", - "labels = inputs * true_w + true_b + noise" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - "height": 347 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 374, - "status": "ok", - "timestamp": 1525154227149, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "O4lsC4ckAcar", - "outputId": "f8becb3f-498b-4cb7-9ef3-608a68cb65d0" - }, - "outputs": [ - { - "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAecAAAFKCAYAAAAnj5dkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzs3Xt8VPWdP/7X3M5MkpkkM8mEAAER\nQoICgUBALkUEQ7FucekDEeWL3VZXu121dler39pu1Vbb77b+2m1/3277qNXa2kUptGttt/tDEWqp\nyDWBiC6ES8slXDJJJpfJ3C+/P8JM5nLOmTOTmWQm83r+RebMnJyTAO/z+Xzen/dbFQqFQiAiIqKc\noR7rCyAiIqJYDM5EREQ5hsGZiIgoxzA4ExER5RgGZyIiohzD4ExERJRjtGN9AWE220DWzm02F8Nu\nd2bt/LmukO+/kO8d4P0X8v0X8r0D+XH/VqtJ8lhBjJy1Ws1YX8KYKuT7L+R7B3j/hXz/hXzvQP7f\nf0EEZyIionzC4ExERJRjGJyJiIhyDIMzERFRjmFwJiIiyjEMzkRERDmGwZmIiCjHMDgTERHlGAZn\nIiKiJDy+ADrtTnh8gVH5fjlTvpOIiCjXBIJBbNt9Gq3tNvT0e2Ap1aOxzopNq2uhUWdvfMvgTERE\nJGHb7tPYdfhi5Ovufk/k683NdVn7vpzWJiIiEuHxBdDabhM91treldUpbgZnIiIiEX0OD3r6PaLH\n7ANu9DnEj2UCgzMREZGIMqMellK96DGzyYAyo/ixTGBwJiIiEqHXadBYZxU91lhXCb0ue20pmRBG\nREQkYdPqWgBDa8z2ATfMJgMa6yojr2cLgzMREZEEjVqNzc112LByBvocHpQZ9VkdMYcxOBMRESWh\n12lQZS4ete/HNWciIsqa0a6sNV5w5ExERBk3VpW1xgsGZyIiyrixqqw1XvDxhYiIMmosK2uNFwzO\nRESUUWNZWWu8YHAmIqKMGsvKWuMFgzMREWXUWFbWGi+YEEZERBk3VpW1xgsGZyIiyrixqqw1XjA4\nExFR1ox2Za3xgmvORESUMawIlhmKRs7t7e34x3/8R3zmM5/Bli1bcPnyZTzxxBMIBAKwWq34zne+\nA0EQYj7zzW9+E8eOHYNKpcJTTz2FhoaGrNwAERGNPVYEy6ykPzGn04lvfOMbWLp0aeS1H/zgB9i8\neTO2bt2K6667Djt27Ij5zMGDB3Hu3Dls27YNzz//PJ5//vnMXzkREeWMcEWw7n4PQhiuCLZt9+mx\nvrS8lDQ4C4KAF198EVVVVZHXDhw4gFtvvRUAsGrVKrz//vsxn3n//ffR3NwMAJgxYwb6+vrgcDgy\ned1ERJQjlFQE43R3apJOa2u1Wmi1sW9zuVyRaeyKigrYbLG/lK6uLsyePTvytcVigc1mg9FozMQ1\nExFRCjy+QFYzppNVBHt150mcPG/ndHcKRpytHQqFMvIes7kYWm320uytVlPWzp0PCvn+C/neAd5/\nId+/xVKCl3/3IfYfvwxbrwvW8iIsmTMR962bDY0mc4HRVFYEq7kInXZXwjG9oMW+41ciX4enu4uL\nBDywfm7GrkFMPv/u0wrOxcXFcLvdMBgMuHr1asyUNwBUVVWhq6sr8nVnZyesVvFqMWF2uzOdS1HE\najXBZhvI2vlzXSHffyHfO8D7L+T7t1pN+L+/ao3pDNVpd+HNvWfhdHkz3hmqYUZFzPcKC4WCou9/\n79glfGLxlKztfc6H373cw0Naj07Lli3Dzp07AQBvvfUWVqxYEXN8+fLlkeMffvghqqqqOKVNRDSK\n3F7/qHaG2rS6Fs1NNagoNUCtAipKDVg+pxpur3hwZgMMeUlHzsePH8e//uu/oqOjA1qtFjt37sQL\nL7yA//2//ze2bduGSZMmYf369QCAf/qnf8K3vvUtLFiwALNnz
8bdd98NlUqFp59+Ous3QkREw+z9\nyTtDZaI4SPR6dnxFMAA4cd6ObpHrYAMMeUmD85w5c/Dqq68mvP6zn/0s4bXvfe97kT8//vjjI7w0\nIiJKl7l0qDNUssCYLFlM6rjcvubooN9YZxWd7mYDDHks30lENA54fAHYel1AKASruRhWQSsbGLUa\nFbbuapcsGpKsqEh4X3NYONELQMx6tlgDjIYZFqxqnAyPL8AALYHBmYgojwWCQbz+zim898EVuL1D\n68gGQY3mxdfhzlumAxDvDJUsuMod37Byhsx6tg0bVs6IBN3oBhg9/W7sOnIRbae78MfWS9xWJYPB\nmYgoj23bfRrvHOmIec3tDeL3f/4L3G6faGeoZEVD1i2bJnv85nmTJNezu/s9eHXnSXz29lkxAVev\n02BPawf2tHTEvFdstE1sfEFElLfkgiwAtJy0RaaOq8zFkdFssqIhFzsdsscRCsFSKp3Mte/4lYSy\nnUqqiNEwBmciojE0krKWckEWAOwDHtHtSmVGvWRwNZsMqKkySh4XdBpYyorQWCdfuyI+4CZ7IOC2\nqlic1iYiGgPpdHGKz5wOB1mxjGwAMJv0otuV9DqNbLKYqViQPO72BvDG3rPYtLoWLrcf70VV/4oW\nv11L7lq5rSoRgzMR0RhQmu0MyAdyqSAKAAvqrZLZ0GJZ1OFkMQBYv+J6/LntciTJLFprexc2rJyB\nLWvr8T/netAz4E14T3zATfZAwKztWAzORESjLNn6a3S2MyAfyDetrkUoFIrL1tagefFU/O2y6ySv\nITqLWmwfs8Ppg0ckMAOxo+IF9VWKA26yBwIaxuBMRDTKlKy/hqeDlQTy/7WmHnfeUhuzz7lmUrmi\n2tLhZLF4SqehUwm4yR4IaBiDMxHRKEtl/VVpINfrNKixKu9hkKwymNJp6HQCrtQDAQ1jcCYiGmWp\nrL9mOpHK6fFh69uncOJcD+wDXlhK9WiorUTzwhpYSg0x3zuVUTEDbmYxOBMRjQGlgS9TiVThpLL4\nJK/ufg/2tAwVB6mIyxgXGxUDQHefO/JnTk9nB4MzEdEYSGU6eP2K6+F0+3HinB29Do9oIE82TR2f\nVCZGKmNcr9OgoswQyRjv7vfAIKgBqODxBliGMwsYnImIxlB4v7LSzk9LZ1fjnjV1KNZrJd/TWGfF\nw3c1Rs6TrJJYPCUZ49F9mlmGM/MYnImIRlH0CFerUaXc+em941fg8vrxd7fNgqlYkNxmVVwkYP3y\naQCSVxKLl0rGeDSxoE7pYXAmIsqAZNPKYiPcYoMOFzodkfco7fzU0t6F1vY/Y7K1BE63T/Q9+49f\nxicWT1FUSSye2aSH1xeI1OVWGtzjgzqlj8GZiGgElJbhFBvhSgXLZJ2fACAE4KJtUPJ4V68LZzv6\nMH1ymWxSmZhBtw9Pv3woci/rV0xXFNxZhjNzGJyJiEZASRnOVNd8ozs/KR3tJlAB33n9aCQDO7G3\nsx71U83QaVU4ftYO+4Abgk4DtzcQWU+OvhclwZ1lODOHwZmIKE1Ky3CmuuZbbtQDKhUaaitj+h+n\nIngtXyv+YUEsO9zjC8DW68K//eqoZC3tZ+9fhEAgiNZTXeh1eGEQhj7r9QVYhjMLGJyJiNKktHpX\nqmu+To8fT790EGaTgImWYlzucY74WsMPC2L0Og0ErRp2kQYWwNC9bH37FE6et6PP4YXZqMf8ukps\nWDkdDqeP+5yzgMGZiChNSqt3KV3zFbQqeP2hyOh1qNuTF3qdGh5fUPazydgH3Hh150mcPG8XXRuX\nuxdBp8G+qNaQdsdQ4RKNWsWtU1nC3eJERCny+ALotA+NZhvrrKLviV9/3bS6FsvmVMue1+sPib6u\nUqV5oVEEnRr7jl9Bd78HIQxPd7/+zikAww8Q4sSvq7W9Cx6feOcqGhmOnImIFBLLzJ4/sxKrF07G\nsVPdsmU4NWo17l1bj5Pn7
Sknebm9QSyfU40j7TbRNWElfH7xkfd7H1zBnbfUQq/TYNPq2si6cp/D\nC0upAbOmluO9qFFzNG6dyh4GZyIihcQys9850oHmpho898BNSctwprqlKUytGirh+T/netIKztWW\nIlzpcYkec3uHksEmVhRj2+7TaDvTjT6HF+VGPRpqK7Bh5QyckHig4Nap7OG0NhGNC+Gp5mxNsybL\nzAYQad0oJRAMIhQKRTKdlQqGgE67SzJhKxmXxy//hlAo8uARnvYOryu/sfes4ql7yhyOnIkor8kV\nAckkpZnZcrbtPo13jqS+Ncpi0qOmypj2vue+QR/0WjU8IlPbBmGogpjcg8ez9y+K/DlZ60jKDAZn\nIsprckVAHr1nYca+z0j7KqdaiCRaSZEOGo0K9VPNMVnTSpUbBcybWYl3Wy8lHFs2txouj1/2wcPh\n9CnuoEWZweBMRHkr2VSz25tkOjcFep0GDTMqsEckwCmZ3k21EEm0C50OPP7DffB4AzAIGoRCoZS2\nVjXOrMTmNXXQadRoOWlDz4AHZSU6LKiz4p5bZ8IfkK5GFr8ljMlfoyPt4Lx9+3a8+eabka+PHz+O\n1tbWyNezZ8/GggULIl+/8sor0Gj4pEVEmZNsqtne78nICCQ8dd52phvAUIJWMDQ03bygXnoKPboZ\nRqqFSOKFE8FSTQibaCnG5jV10KjVQ9nYwRCOtneh1+FB25luaDSnsWl1rWSiGteVx0baf283btyI\njRs3AgAOHjyI//7v/445bjQa8eqrr47s6oiIZCSbajaX6jHQJ56lrEQ4uO48dCGmjGbw2rbfeTMr\nRYtwSK2Dz59Zmdaas5jwA4Icg6DGV/6uKdKAY9vu0zH3Eb0EEH7A4LpybsjItPYPf/hDvPDCC5k4\nFRGRYnJbkxrrKmEQtBgQ+Vwq7R27+z1QSxQBaTvdDc+qQMI5tr7dHjP9HQ6CtyyYhBpriWw3KRWk\nSn7EShaYAeBjDZNQrB/6b37A6cXh/+kUfV+4tCfXlXPHiINzW1sbJk6cCKs1NtXe6/XiscceQ0dH\nB9auXYvPfvazsucxm4uh1WbvL4LVasraufNBId9/Id87MP7v/+G7GlFcJGD/8cuw2V0wl+qxZM5E\nPLh+LoDY+w8Egnj5dx8OvbfXBWt5EZbMmYj71s2GRjO8s/TFNz6ICfhSgdA+4IZG0MFaWRI5/0/e\n+ADvHktclwaAAx9ehcsjPy2tJDADgLXcgAX1VXj70PlIk4toRXotHlg/F3pBi5++eRxvHzwHj1d8\nnTr+PmoUXkOuy+e/+yMOzjt27MCnPvWphNefeOIJ3HHHHVCpVNiyZQuampowd+5cyfPY7SMv7C7F\najXBZhN7fi4MhXz/hXzvQOHc/7qlUzEw6MFRXxfs/R4cOH4ZXq8fD9/ViJ6e4VHq1l3tMUG30+7C\nm3vPwunyxrR3fO+Ysqlns8mAgNcX+RnHnz9essAMDK1jz5tZibbT3TFtHOPNq63E8jnVeOvAedHz\neLx+/OWCHbuOXExa9CT+PsaDfPi7L/fwMOLgfODAAXz1q19NeP2ee+6J/HnJkiVob2+XDc5EROmS\nWkstLhKwfvk0ANlp7xidLDWSrVLRFtRbsbm5Dp5VQ1PvxmIBb+w9G7MWPG9mBUKhEP7tV0clR9qV\n5UUo0mvRclJ8KlvqPig3jKhC2NWrV1FSUgJBEGJeP3v2LB577DGEQiH4/X60tLRg5syZI7pQIiIx\nckFx//HLkYphSoqIAMNJZmLUqqEmFBWlBjQ31WDT6tpIZTKb3Zk0qMtVBrOY9JFzAsPblor1Wmxu\nrsNzD9yEbz64BF/7TBM8ngDeOdJxrWuVuCVzJg7tX05SVWzZnGomfeWgEY2cbTYbLBZL5Ouf/OQn\nWLRoERobG1FdXY0777wTarUaq1evRkNDw4gvlojGv2TJWvHkgq7N7sLZjj5Mn1yWkfaOK+d
PwtrF\nU1Fm1EOrUSVkZAs6FTw+8bGsXqfGktlV+GPr5YRjy+dUY8vaetn71WpU2HXkIlpOdsoG3IprmeH3\nrZuNy1f7YTEJku+3lOpx79r6SDY35Y4RBec5c+bgpz/9aeTrBx98MPLnL33pSyM5NREVGLkynHLB\nQy7oqtTAC68fjZxr3sxK7BbZyiTW3hEQ31YUvpb49eVk+5c9viBuXVgDrUYje14p8ZXQxKgAPHpn\nA2qqTNBo1NDrNFhQXyX5uQV1Vk5n5yhWCCOinCBXhlNsL3GY3Eg3nMUcPtetCyejualGci9v9Kg9\nflsRAHT3uSN/Tmd9efeRDty7dlbS7UrxswceX0DR2rGl1ABrXAWvTatrEQyFsO+DK5HEMoOgwfK5\nnM7OZQzORDTmlCZrSYke6fb0u6GSKNBx9FQ3nnvgpoTgGAgGsXVXu+iovaLMkDCir59qTqsUZ9uZ\nHnh8AckymFKzB6saJyddOwbEE7s0ajW2rKnHxltqYbM7AZUK1vIijphzHIMzEY25kXZ80qjVkZHu\n2Y4+vPD6UdH39Qy4I2vQ0eeTG7UHAsGEgiL7jl+BRF0SWcnuRap4idcfkK0IJuhUWNEwSXYkrNdp\nUFOVv/t+Cw2DMxGNuZF2fArT6zSYPrlMeg0awHdePxpJmtq0uhb+QEhy1L637RK8EoU7lBYLiSZ1\nL0Mj91N496h48ZK2092yFcG8vhBUKhUTu8YR/iaJaMyF143FpLoHV+5c4QAXHpFu231adtTu8QbT\nCsJSpO4lvE9bKgD3ObwoNwriB69pbe+KbBsDALfXj067M+Y1yh8cORNRTshk44XwZ9rOdMPW64IK\n4lPCre1dWLds2oi6RSlRbhTQNKtK9F6UFC+xlBrQUFsRU2glXnjKPLxG3namGza7S3HWO+UWBmci\nygnR68YjbbwQPtfnNhTh4LEOfEdiDdo+4IbL45fM9s6EcqOAZ+9bDFOx+MhXSUWy6IeUd1vFR9jh\nKfN0s94pt/AxiohySjiTOZXAHK7SJTaFayrWoUKi4lc4oG1aXYtbF05GNgaWbq8fb+w9i8vdg6LX\nl6wi2c3zqrGqcTL8gRDuWlWLxTdMEH1vY10lAOktXvHT3pTbOHImorwltfXozlumY8cfz0amdvWC\neNSNXgMOBkOi3Z1SMVTeU4VA1NDW7R3K9t7TeikmES08xSy3T3tSZQk+/Isde49dgV7QAAjB7Q3C\nIKgBqOD1BWKm/7v73CPKeqfcweBMRHlLagr35PleXOh0RF53X8u41qiBwLUAbBA0CIVCCASDQxnb\np7pGdC2L6q24Z00dnvv5Ick9yVJTzGLr7cUGbdw9RCd7Dd3EsjnVuDeq7Gemst5p7DE4E1Fekkuk\n6rA5RF8PRI2M3d6h5hHBENBUZ0WvI3mRDznGEgFeXwB2BcVC4gurxK+3F+m1+Porh5Ke5+T53piv\n5Ubh7DyVXxiciSjjUm1ekY6efrdkhrXcnuB477Z2yGZBp3Ieh9MLs0yjiTD7gBu2XhcErTrmZxRe\nb+9U0OEKGPoZxE9VR2eqd/W6RpT1TmOHwZmIMibd5hXp2HVEOrtarppWvFQCebLzHDphg0bBbQo6\nDb63rRV2hw8Wk4AF9VUxPyO56eloekGTMFUdnal+5q/dWX1AouxhtjYRZUx4Dbi734MQYot9ZJLH\nF0Dbaek14urKsUt6CihIKnN7A7A7fACAngEvdh2+iNfeORXJOgcgWUhFKYOgTTnrnXIHR85ElBEj\nbV6RCrkpbQCwlhlwyebMyPfKJLUKQAgQi99/bO3A0XYb7ANeWEr1mDezErcunIyWk12wO8Tv1Xtt\n+YAZ2OMPR85ElBFKmldkytuHz0seU6uAY6d7Mva9MikoEZiBofaWPQPeyIzD7iMdUKlUeOa+RZKl\nO5mBPX4xOBNRRsgV05ALIh5fABc7B3DR5lBUJMPjC2D
/h9K9jTO1hpwLWtu7IOg0aJpVJXqcGdjj\nF6e1iQjAyDOsU93GEwgG8do7p7Dvg8uRfbsGQYPlc6tx960zJRPIbL2umD2/41l4xiGTdccpPzA4\nExW4TGZYpxJEtu0+jd1HYrcwhfceq1QqbG6uE39gCOXe0DiV7PBUhGccMll3nPIDgzNRgctkowSl\nQcTjC6DlpPTUdGu7DYFAEG1nuhMeGKzmYhgEdWS0nQuyNZUeP+MQ3gdN4x/XnIkKWLIMa6WNEuIb\nTyRrXtHn8MgW6uju92BP6yXRLVl6nQY33Sje/GG8ELRqNDfVcNq6gHHkTFTAlGRYy43U0p0SL9Jr\nIWhU8AbEh5xS08QtJ20IBEM4fjY3s7GTUauGZuXNJj2cHr/o2rleq8a3/mEpypmFXdAYnIkKWCqN\nEsTWf1OdEo8O5lKBGZCeJu4Z8GSk1OZYWXzDBKxfcT3KjHr8+t0zoslzK+ZPYmAmBmeiQqYkw1pq\ndLx+xfUpFx2JD+ZSykp06Bv0JbyercSr0WAQNNiyth7F+qH/dpmBTXIYnIkKXLIgITU6drr9KU2J\ny61vR6soNWDuDDP+2Ho54Vi+BmYA+FjDxEhgBpQnz1FhYnAmKnByQUIuoJ44Z1c0JR6eDvf6Aoo6\nLTXWVSKYg9ul0lVRKj8iZgY2iWFwJiIA4kFCLmGs1+HB4hsmoPujqwnH6qeWIxAMYeuu9pjpcL3M\nFihTkRY3Xl+B25dch+d/cXjkN5QDyo0CvvaZJpiKxctvEklJOzgfOHAAjz76KGbOnAkAqKurw7/8\ny79Eju/btw/f/e53odFocPPNN+Ohhx4a+dUS0aiSSxgTdBq0X+wFMLwWrNepgJAK7x+/gpb2zphA\nnKz94YDLjwMfXR1KFvPnzh7mkegf9MLl8TM4U8pGNHJevHgxfvCDH4gee+655/DSSy9hwoQJ2LJl\nC9auXYvaWiY6EOWD6MxsqYQxtzcQ2QoUXgv2+EIAQteOiwdYg6BGKAR4fOLHMxGY9Tq15PnV6qEm\nE+kyCBrF5UPZmILSlZVp7QsXLqCsrAwTJ04EAKxcuRLvv/8+gzNRjhPLzJ4/sxKrF07GsVPdsA+4\nUW6U3qOrRCYre5mNAubXWdF2ujsmmW353Go8+zPxqfFQEBC0Knj9qa1rq1XAyvmToFar8M4RZdu5\nGmZYmORFaRlRcD59+jT+4R/+AX19fXj44YexfPlyAIDNZoPFYom8z2Kx4MKFCyO7UiLKOrHM7HeO\ndKC5qQbPPXDTUGKXP4inXzo4hlc5zHXtAeHhDXOgUathLS+CVqPC1rfbJbddhQD4ZfZYSwmGgOam\nKagyF0GlUqG1vQs9A27ZUt/NTVNS/j5EwAiC87Rp0/Dwww/jE5/4BC5cuIBPf/rTeOuttyAI6a2t\nmM3F0Gqz94RptZqydu58UMj3X8j3Dii/f6fLiz+3JW5fAoDWU134zLo5qJlUDrfXD6u5CJ12VyYv\nMy1ubwB7Wjqwp6UDVnMR5s6ohKBTY0/rJdnPpbsl670Pr+LzG+bh0XsWwu3140r3IL7+0/2w9boT\n3qtWD73/wfVzodGMTaVk/t3P3/tPOzhPmDABt99+OwBg6tSpqKysxNWrVzFlyhRUVVWhq6sr8t6r\nV6+iqkq8H2mY3e5M91KSslpNsNkGsnb+XFfI91/I9w4ov/9AMIiv/fSg5FR1d58bD39nN5pmVWHT\n6lo0zKhQVExkNNnsLuw+nN0Zuv0fXMa6pddFpqpLtGrMq60U/VkEg8Af9v0VXq8/5QYimcC/+7l/\n/3IPD2k/zr355pt46aWXAAxNY3d3d2PChKFi9DU1NXA4HLh48SL8fj/27NkTmfImotwQ3axi665T\nuNwj/4Dc6/BGmk9sWl2L5qYaVJQaoFYNJUllmirjZxy5ngEP+hyxWeebVtdiVeMkqCUuOJUGIkRh\naY+cV69ejccffxz
vvPMOfD4fnnnmGfz+97+HyWTCmjVr8Mwzz+Cxxx4DANx+++24/vrrM3bRRJS+\n+KQvs0nAoMuv+PPh0pzRhUuMxTq8sfcvOHyiE70O6W5TqcjFMiRq1VDTjmgatRprF0/FHyWm0pU0\nECGKl3ZwNhqN+PGPfyx5fNGiRdi2bVu6pyeiLIlP+pJr3SgmOthEFy7Z3FyHdcum4emXD2YsQOea\nYAii+5ZTaSBCpAT7ORMVEKX1reWIdasKT4+bigUUG8Zv4cGKUr1ooA03EBETbiBClIrx+6+IiBLI\nleNUSq5bVZFei0td2UvuHGuNdVbJQMsuU5RJDM5EBaTMqIfZJIhOZet1ahiLdOgZ8KC8RI95MysQ\nCoVw7HQ3+hxeWEqTd6sCRhb4c9nK+ZNkAy27TFEmMTgTFRC9ToOSIvHgXGUuxlP3LoxJ8Gptt6HP\n4UW5UY+G2gpsWl0LjVqdkenxfPOJm6ZCo06+EsguU5QJDM5EBcTjC8Dp9okec7p9sNmdsJqL8et3\nz8SMiu0OD/a0dECjVmFzc11GpsfzidmoY1IXjSoGZ6ICIhdUu/s9+NrLh2AxCXB6xPfltrZ3Yf2K\n6fjDgXNQqSBbunI8MRbLT1FHNwrhVDZlAoMzUQGR2/ITJre1yj7gxvM/P5y0YMl4M+jyweMLJARe\nsaS4xjprZPqfKF3820NUQOS2/Cih1agKLjADQK8jsTIYMJwU193vQQhDsw/hKmpEI8HgTDTORO87\nFhNdelOVYo1Mf7odI/KcWCERuaQ4luykkeK0NlEOS2UtU2yKdfm8yVi3NDbLOLzlZ/2K6/HLnSdx\n4KNOxaUyg5lrxZxXxAqJyK3fs2QnjRSDM1EOSmctU2zf8Zt7z8Lp8op2RXpj71+w/6POrN1DPlOr\nhmp7W2QKibBkJ2UTgzNRDhILtOGvxQJtsinWDStnxIz8CnGfcipWzp+EtYunxsxYDDi9uNjpQE2V\nEaZiIbJ+L9YukiU7aaQYnIlyTKqBFkg+xWrrdUHQqiPBps/hkc3YLiRTqoxwuv0JJTfDMxRevx/P\n/6IFHTYHgqGhUfVkqxFf+fQCluykrGFwJsox6axlyk2xCjoN/u1XR2Ef8Eamx29fMhVq1VCXpZEQ\ntGp4/bm/EF1jLcGDd8zGntYOtJ3uTgik/kBIcm3/+V+04EKnI/J1MARc6HTg+V+04Nn7FrNkJ2UF\ngzNRjklnLVNuitXtDcDtHcocDk+PO93+EQdmAPjnTQ048FEn3j16KSPny4abGybi3tvqoVGrce/H\n6+FZlZhkp1FDNHlrwOlFh82R8DoAdNgcGHB6I1PcTP6iTOJWKqIck077wUAwiFAoBIMwfMwgqGEQ\nxP+Jnzhnh8UkiB5LxSt/OIlFCLMMAAAgAElEQVSb508as0phgjb5XjCNJvY94UCqZIR7sdMh+dAR\nDA0dJ8oGBmeiHBS9F1mtAipKDWhuqpFcy9y2+zTeOdIRGSEDgNsbhNsrPuXc6/DghussI77OK3YX\n/s8vWyDoUtwwnQHlJTo01FZCp5X/b2xP66W0i4LUVBmhlrg1tWroOFE2cFqbKAeF9yKvWzYtJkM4\nnscXgM3uTDnzWtBpcM+aOpzvdMSsp6bD4xubNefeQR8On1B231KJdMmYigVMthpFf0aTreK/E6JM\nYHAmyiHhoiPRLRvF9jlH74NOJ+s6FArhavcgBl3SdbTHk5EUBfnKpxdIZmsTZQuDM1EOiC86ohc0\nMVPU8fuc4/dBp8rjC+Ibvzgy4uvOFyMpCiJotXj2vsUJ+5yJsolrzkQ5YOvb7TENFKIDc7TW9i4M\nOL0sIJKiTBQFMRULuGGahYGZRgVHzkRjKBAMYuuuU3j36CVF7+/pd+Nip0NyH3Suq7YUodPuyvq2\nK/W1XtNWcxEaZlSwKAjlHQZnojHi8QXwy50n8d7xK4o/oxc0qKkyJu3JnItunj8Ra
xdNRZFei1f/\nvxM4cd4OlzeYkWIo8UIAHr97PhbPm4yBPldmT040ChiciUZZeH35yImrsDt8KX46hDf2nsWgO9XP\njb332i7jT0cvJ7yejVF0eYke0yeXwSBoMZD50xNlHdeciUZZOJkr9cA8tHd5T+ulhP3LGjWwsnEi\nKkpztxNSIMmOK71O+X9HS26sgtmokzw+n40nKM8xOBMl4fEF0Gl3wuMLiH6d6rmykcwVCAJqlVqy\nslg+ULJf2iBo0NxUg/s/eSMWzpog+p4pVUZsbp6Z6csjGlWc1iaSEL+9yWwSUFIkwOn2Ke6xHC+b\n3aBaT9owt9aSN80o4pmKdBB0atmfT4lBiw0rZ0CjVsd0hOrpd6PMKKBxZiU2r6lT/PsgylUMzkQS\n4vcS9wx40TMwXLQjWY/leIFgEDsPXchKAhQA9A56sfeY8uSyXNNYXwlBq5Hdv20f8ESKiYSrqLEj\nFI1HIwrO3/72t3HkyBH4/X587nOfw8c//vHIsdWrV6O6uhoazdA/lhdeeAETJohPQxHlmlSmn5WW\nhty2+zT2tHSM6LoErQpef462f7rm5saJOHuxHxdtg4o/o9WocO/H6wEAgWAI77Z2iD7AiBUTYUco\nGo/SDs779+/HqVOnsG3bNtjtdnzqU5+KCc4A8OKLL6KkpGTEF0k02uR6KsdTUhoyU2vNC+qtOH6m\nBw63f8TnyhatSoWnP7sIW99uR+upLvQ5vBB0atk15dJiHfyBELQaFTRqFXRa8fdnopgIUT5IOzgv\nWrQIDQ0NAIDS0lK4XC4EAoHISJkon8n1VI6npDSkXLBXqQBTkYB+p3yda4OggVarzunADADvHb+C\njatm4t61s3DX6qFa4V5fAF97+ZDkZ+wOL/ocHuw6clF0WnsoG30yi4lQwUg7OGs0GhQXD40UduzY\ngZtvvjkhMD/99NPo6OjAwoUL8dhjj0Glkm4rZzYXQ6vNXmC3Wk1ZO3c+KOT7T/fel8+bjDf3nlXw\nvkmomVQu+x5TWRGs5qHqWAnXV16EuuvK8WeRPcDRqiuKse+D3F9T9niD8KtUqCwrwmC3EyUmA2pM\nBljLDbD1ukU/Yy0vQs2kcrT96pjo8UAQMOh1qJ5QlvL18O9+4crn+x9xQtiuXbuwY8cOvPzyyzGv\nf+ELX8CKFStQVlaGhx56CDt37sRtt90meR673TnSS5FktZpgsxVuKYJCvv+R3Pu6pVPhdHnR2t4F\n+4Ab5UY9Sop0cLp9sA94YDYZ0FhXiXVLp8JmG4h0lJJKTGqYUSE6Kuwf9CQNzBMtxfjr5fz5Hf7i\nDx/hg9Ndkf3YBkGDynKD5PsbZlTg4qVe0YeXsPfbLmPd0utSmtbm3/3CvHcgP+5f7uFhRMF57969\n+PGPf4yf/vSnMJliv8n69esjf7755pvR3t4uG5yJco1GrcaGlTNwc8NEQKWCtbwIep0mIQgP1cdu\nl2zvGHbnLdNx8nxvpPVgWHxBkXhmowCvP/U91WPp0EedMV+7vQFc7BxEtaUI9gFPZD3ZIGiwfG41\nNq2uhT8QQrlRQK9DfHq/d9CTdttHonyTdnAeGBjAt7/9bbzyyisoLy9POPbFL34RP/rRjyAIAg4d\nOoS1a9eO+GKJRkv8HufogBufHRy/5Sp+i1U4mO88eB4XOh0pX8sN0yx4P4X627nsSo8LZpMejTPL\nsPam61BtKY6MhDVqoHFmJfa0ijcBsYyg7SNRvkk7OP/hD3+A3W7HF7/4xchrN910E+rr67FmzRrc\nfPPN2LRpE/R6PW688UaOmimvSAXcQDAU2fIDAANOLw6f6BQ7BVrbbQgEgmg7042efg9kUi4kLZ9T\njXvWzMTJ8/a8a3QhxT7gwf6POqFRq7FlbX3Msc1r6nC6o1/0IYaZ2lRIVKFQKCc2TWZzbSAf1h6y\nqZDvP5179/gC+OqL+0WDoVoFLLphAjavmYk39
v4FR050ot+ZnSYUpmItHr2zAZOtppS7V+ULi0nA\ngvqqmCWAcBvNo+1d6B30wHJtbT+VSmxh/LtfmPcO5Mf9Z23NmSjfhaeci/RauDx+lBn1stuegiHg\nwEdXceCjq0nPrbrWUzhdA04/nvtFCwyCGotvqIJBUCddn843PQPehCprGrUad62qxarGyUAoBKu5\nmCNmKjgMzlSQwmvKLSc70TPgjZTUrCjVY870CpSVCOgdlN93nEym5qTc3iD+dOwKplQZ01qzzgfh\nKmtajUpyrZ/1sqmQMDhTQYpfUw5nT3f3e/DuUfGEpLE2MDg+1pzFhKusxRchSbV+OdF4wUdRKjjZ\natuYbb2D2VnbzgVmkwFFeq3k76W1vSutFp1E+YrBmcalcM9ltzex1GUqdbMzqdwoYGXjJOi16f2z\ny9VlV4Mw8v9GGusq4fL4JX8v4ZE1UaHgtDaNK/H7k63mIjTMqIhZs0ylbnaYkh7J1RVF6OxxiXZT\nUqmAL909H5ayIrSf68XlntQr4uXqwLGi1ICOLvn7MQgaeH0BmE16FBt0GHT50OsYrrIWLkIi9XtR\nUr+caDxhcKZxJX4tudPuSliz1Os0aKyzyvYNjmYx6THnegsOfHQVnmsBWqMeanPo8YWgAhAC4HL7\nJfs0h0LAsz8/BBVUst2Z8pHD5cOqBZPRdrob3f3uoZkBNeDzBSPBd/2K6XA4vZGqamKlTjVqSP5e\nuMeZCg2DM40bcmvJ8T2Xw92NWk7a0DPgiWRriykp0uFPbbG1rwNBYIKlCJe6nAh/rC/JmrDXFwKQ\nE2UFMqp/0Ie1i6bgrlW1kYALICH4FuuH/7uR6sEc/r2E65lHj6yJCgmDM40bcmvJ8T2XNWo1NjfX\nYcPKGZF9zg6XD7sOX0DbmZ5IYGiYYUHbmW7Rc15KMpVbKCylhkgQjg646dTAjv+9SDURIRrvGJxp\n3JBbS5Zas4wOKKZiAfeunRUz5WqzOyVrPcspLdZlrXJYrsnGlLPUyJqoUDBbm8aN8FqymPgAEs7m\njt+eEw7MxmIBv373DL63vS3l67CY9Nh068yUP5ePjEVaTjkTZQFHzjSuxK9ZVpYPZ2sD0t2m7rxl\nOnb88Sxa223o7vcoys6W4nB58dPffZSxe8plTrcfTrcfpmJhrC+FaFxhcKZxJX7Ncsa0Cgz0uSLH\npbpNnTzfG1MaM93APPTZ8Zf0JSUYAi52OnDDNMtYXwrRuMJpbRqXwmuWBmH4+VMum3u81qzONrUK\nqKkyjvVlEI07DM5UMMaqMthYaaqvhEFQlqhVbhSgVg01/qixlkCvG/6vwSBoYCwSn2SbbDVySpso\nCzitTQWjzKiH2SSgZ2Bk3abyhVarwavP3ob/OW2D1+fHd147KloAxSBo8Ox9iyMtM8NFQmx2J6BS\nwVpeBJUqhOd/0YIOmwPB0NCIebLViK98esEY3BnR+MfgTAVDr9Ng1nUW7Dt+ZawvZVS0n+8FANRY\njRhwehGSqrICQNBpYkbAep0GNVWxjeCfvW8xBpxeXOx0oKaKI2aibGJwpoKyec1MtLTb4PbmaKHq\nDOp1eNDV68L2XSfx52OX4Q2IB2ePNxBToEWOqVhg8hfRKOCaM41LUl2p9DoNKssMY3RVo8tsMuB3\ne89i95EO2ezzcIUvIsodHDmTLLEGBbkm+hq1GpVsV6ptu0/jom0w49dQXiKgpEiLrj53zjS2mD3d\njP0fJK9uJlbhKx9+70TjGYMziZIq1hHdenGsiV1jsUEXsy0quivVhpUzJLdSjVTvoBcurz9nArOx\nSIu2093odcgnvy2bUx1T4Ssffu9EhYDBmURJFesAhlsvjjWxa5Tq0dza3oWlN07I6laqXAnMAOBw\n+ZO/CYBOp4r5Oh9+70SFgI/ClCBZ68X4etRjQe4axXT3u/GDX7eNw4aNI/Nu62Vs230aQH783okK\nBYMzJVDSe
nGspVNQJFm/5ULVctIWWWPO9d87UaFgcKYE4daLYqRaL442uWuk1NgHPJHkr1z/vRMV\nCgZnSpBK68WxIneNwFDVK7UKqCiQbVMjYTbpI1nZuf57JyoUDM4katPqWjQ31aCi1HCt5rIBzU01\nOdW7d/2K6yVrR5cYtHjms4vw/X++BRUcYctaUG+NBN58+L0TFYK0s7W/+c1v4tixY1CpVHjqqafQ\n0NAQObZv3z5897vfhUajwc0334yHHnooIxdLoye+9WIu7nd1OH3wSFT6Cq+dlhn1mDXVjPcKpGQn\nANRPLcPJ831J32cQNFg2N3YrVT783okKQVrB+eDBgzh37hy2bduGM2fO4KmnnsK2bdsix5977jm8\n9NJLmDBhArZs2YK1a9eitpZP3vko3HpxrMgVwwivkYptnwoB+P6ONiw/1Y31N08vmOCsVgGf/cQN\n+MqL+xEQ2dmlF9T40j2NEDRqWM3FkoF3rH/vRIUureD8/vvvo7m5GQAwY8YM9PX1weFwwGg04sKF\nCygrK8PEiRMBACtXrsT777/P4EwpUVIMI7xGGr0vN1p3vwdv7j2LQx8WRmAGhjpFVZmLcUvjZLxz\npCPh+MfmTsT0iWVjcGVElIq0gnNXVxdmz54d+dpiscBms8FoNMJms8FiscQcu3DhQtJzms3F0Gqz\nN31mtZqSv2kcy7f7f/GND0SLYRQXCXhg/dzI6w/f1YjiIgHvf3AJtl636Lku9zizfr1jTa0GplWX\n4juPrIAgaPHIpgUoKdZj//HLsPW6YC0vwpI5E3HfutnQaAor1STf/u5nUiHfO5Df95+RCmGh0MhL\nO9jt2fsP1Go1wWYbyNr5c12u33/81LXT48NbB86Jvve9Y5fwicVTYqZjP7F4CqZXG/Fv29tG65Jz\nyv23z0JDbSVMxQL6+lyR19cvn4Z7b78BZ/7aHfnZ9vRkvq54Lsv1v/vZVMj3DuTH/cs9PKQVnKuq\nqtDV1RX5urOzE1arVfTY1atXUVVVlc63oXFOaura4fZJtnQMF8OoMhcnfL4QVZQa0HTDBMm1Y4Og\n5doxUR5Ka35r+fLl2LlzJwDgww8/RFVVFYxGIwCgpqYGDocDFy9ehN/vx549e7B8+fLMXTGNG+E6\nzt39HoQwPHXderJT8jPlJj28vgA8vkDC5wsR9x8TjU9pjZwXLFiA2bNn4+6774ZKpcLTTz+N3/zm\nNzCZTFizZg2eeeYZPPbYYwCA22+/Hddff31GL5ryn1wdZ49POtQ6nD48/fIhmE0CnJ7Cq/WsUSOS\nhW0Q1AiGQggEg+wYRTTOpL3m/Pjjj8d8PWvWrMifFy1aFLO1igqT3DaodGpjA4DXPxSZegbkWyGO\nFwZBA68vALPJgCK9JqYXtdsbxO4jHVCrVOwYRTTOsGUkZZySbVBye5QNgkZyzXm8U6uG9mhbTAY0\n1lVi/Yrp6HN4sPPQefz52GXRz7S2d2HDyhmc3iYaRxicKeOU9ASW26NcWW5AV687EqD1WjU8/uz3\nSv7nuxrQM+DBG386g95BZf2QM23l/ElYu3hqzGzDG3vP4k9HxQMzEJskR0TjA4MzZZTcWvKREzas\nWzYNpmIBACJlI1vbu2AfcMNsMqDYoMWFTkfsOf1BGAQ13N7EAG0QNCjSa2Af4TS3SgW88t8n0DPg\nRaa3AWvUgFajhscn/4BhEDTYcEstivXD/yyV9K1mxyii8YfBmTJKtieww4OnXz6IpllVkSnu6DrO\nRXotvv7KIYkzq0RftZYXYVJlMQ58JJ3hrUQoNLyOLVb2MhUq1dD5wl2xnti8AIJWjadfPoheh/RD\nhNcXgMPpjQnOStbmmbFNNP4wxZMyKlmf5V6HF7sOX8R/vH0y8lq4jrPL45cMRF5fANWWooTXL3Q6\ncPRUl8gnxk64Jk8wBNh63fiXn+7H7/b9FQvrpVtcAuIjYLmfp1oFrGqcxI5
RROMQgzNlVLI+y2F/\nbL2MV986iUBweJhqLNZBL9ECUqdV42qPS/RYsuniseb2BrHr8EWEADQ31Ui2uWyYYUGfwwOPbzgZ\nTu7nubJxMu5dO4vbqIjGIU5rU0Z5fAGsapyMQDCElnYb+mSmcfe0dECjHt4G9Js/nZXM0s71AKzE\nsVPdeO6Bm7B+xfXY+vYpnDhnR6/Dg3KjHiVFOrSd6cYfWy8lZLeLrc031lVyxEw0jjE4U0ZEb5/q\n7vdA0Krg9Sev2xXeBgQA+z6QzkiWky9br6Kzqv/+kzdG9oHvPHQBe1qGO0jFZ7ezxzJR4eF8GGVE\ndClNAIoCMzAcsGx2p2g2thJL50xAjbUkrc+Opvg1Zb1OgzKjHm2nxdfMW9u7Eqa4q2R6MBPR+MHg\nTCOmZLuPlEjAUolnYyvh8wfh8ozNvuRUiGVVy2a3X3twIaLCw+BMafH4Aui0OyNTs+l2hQoHLGt5\nETRq8QCdLGy/13ZFtNKYEjoNUG4U0vqsHL1ODUE7fOUGQYPQtTrY0eSysbl/mahwcc2ZUiJWmrOh\nthJmk6Co3nV8ecropCadVoWAN3E6XC+oMb/Wiv0fXRU950g6UvkCkN17nK4qc3FMMRW3N4B3jnRA\nFVcHW6tRodigE3244P5losLF4EwpESvNuaelA1OqjIqC88caqnHTDdWoqTJGKoUBQ9O7UmvObm8Q\nK+dPwoGPruZ0a0iVauiho6G2AsdOiU/zRyfA9Tk82HnwfEJFNACYUmVkNjZRAWNwJsXk1padbh9W\nNU7C+x9eFc2cVquBSRUl+PAvduw9diVhu5CxWCebdf3j3x5XFJjD1blGm8WkxxfvmoeyEgEXOx0x\n2dfRevrd+OXOkzhx3o6efo/kUrvT7Yc/EMp4KVEiyg8MzqSYfPKSB2sXT8WGW2rx2tvtQ8FnwIOy\nEgGzppZDr9fi3dZLkffHbxd6Y+9fZLdD9Q36FF3jWARmAJhfV4k/HbsUme5Xq4YqhMXT6dR47/iV\nyNdS18tmFkSFjcGZFJNr8xhOXtLrNLj/2h5eW68LCIVQZtRL1sxube/CumXT0HJyZLWxR5v62gjd\nUjq0dh4KhWKm+6WCrldhMRUmgxEVNgZnUkyuzWN08lIgGMSv3z0TGUWWGQXJpCv7gBsXOx2K1qtz\nycrGyVg1fxKgUqGsRJB8+JAaQSfDZDCiwsbgTEmFt0uVGfWKSknGJ43JZUObTQZUmYvSDmIjoVYD\nQZmBrFoFlJYMPViEr89i0mN+XSVUAL6/oy3pw0cIQGmxgH6n/MNH/EicyWBEhY3BmSSJbZsKJ3FJ\nlZJMtSBJY10lXB7/qAdmYCgw11SVoMM2KDoNfcuCydh4S22knaXL40eZUY9fv3tG8cOHqViH/sHk\nswIrGydj7aIpLM1JRAAYnEmG2Lap6CQusWSlPodHtiCI2ahH36AHZpMBc6ab4XT78b3tbZm/eIVc\n7gD+n4eW4Ve7T+PDv3RjwBWA2ShgYVTP6fB9moqFlB8++gd9srMCFpMeC+qHs9aJiAAGZ5IgF4TC\ne3XjR3iBYBA7D12QDEYVpQZ87TNNcLh82HXkIt4/fjntetqZYh9ww+UJwFgsQNBpoXIFoJaoVAbI\nZ6wDQw8f9riSm1KBefmcamxZW8+RMhEl4KM6iVJa8zm6jOe23aexp6VDMhg11lXCVCxgT2sH9rR0\njHlgBoCyEj12HjofadoRwvAMwbbdp2PeGwgGsfPgecm9yRWlBjx17wLJcqBq1VAp0opSA5qbavCZ\n22cxMBORKI6cCUBs0le4W5LctiljsYCtu9qHM7JLBLi80s0nJltLsGl17YiaZGSD3eHBn4+Jt6qM\nnyHYtvs09kTt1Y7XWFeJQDAk2cM6BODxu+dj+uQyBmUiksXgXODkkr6ktk0V6TX4j7dO4v0Ph2td\n9yZJehp0+eDxBbD17VNpN6nIFqmRfnQ
hELmHCrVqKKFr0+pa+AMhyYcai8nAwExEijA457H40W46\n5JK+Nq2uxcnzvQm1ny/aBnHRNpjS9+lzeLH17VPYF1UdK9eVG/WRQiBy0/yhELB20RRo1Gpo1FC0\nF5yISA6Dcx6SG+2mkvGbLOlr3bJpcLqVlc1Mxlyqx4lzPSM6h0EYCmweXwAqZH9fdEmRLhJM5ab5\nLaWx1byU7AUnIpLD4JyHkm1xUipZ0tfFTkfafZrjOZw+eP3KE8AMggZeXwDma12emhfWwFJqiFz3\nHw6cw5+Oiq8VZ4rTPTQVr9dpFFdHAwCNWo3NzXWSe8GJiJJJKzj7/X585Stfwfnz5xEIBPDEE0+g\nqakp5j2zZ8/GggULIl+/8sor0Gj4H9RIpbPFSUqypK+aKqPkcTla9VBWcnQZaaWBWaUCbmmcjA0r\nZ8Dh9IoGtipzMdYumpr14Gwf8MQ0n0h1RKzXadi4gojSklZw/u1vf4uioiK89tprOHXqFL785S9j\nx44dMe8xGo149dVXM3KRNEzJFielASHZaFDQaVA/1ZzyOnEKA+QEEyxFuPfj9QCAYr30X09LqQEG\nQZ32dixBq4I/EJKdGo9vPpFsRJyJHAAiIiDN4HzHHXfgk5/8JADAYrGgt7c3oxdF0pR0hkqF2Ghw\n/swKBEMhfPXF/ejp90TWeuVaOmaKvd+DAac3UipTPshJFwvR69TwyHSA8vmTL1hLJXDFj4gzlQNA\nRBSWVnDW6XSRP//85z+PBOpoXq8Xjz32GDo6OrB27Vp89rOfTf8qKSKVtU8lxEaDv373DN6JOn84\nKC+bUw29To22Mz2wD7gh6DIftD2+IJ5+6SD6Br2yQa7P4YFH4vuqVEBTfVVM3+R4ZpMeKhVEH3LU\nKmDl/EmKE7gylQNARBSWNDhv374d27dvj3ntkUcewYoVK/Af//Ef+PDDD/HjH/844XNPPPEE7rjj\nDqhUKmzZsgVNTU2YO3eu5Pcxm4uh1WZvKtBqNWXt3KPt4bsaUVwkYP/xy+jqdaGyvAhL5kzEfetm\nQ6MRH6kpuf8aAG6vH21nukWPn+7oww+fWA0AuNLtRCAYwH/vO4cDx6+g15G5vcvhPdPhIFdcJOCB\n9bF/d0xlRbCai9BpdyV83lpehEfubkTFzpN4++A5uDyJQfxj8ycDAN7cezbh2G1Lp+HzG+Ypula5\nn1fbmW58bkMRDMLY5l2Op7/76Sjk+y/kewfy+/6T/q+xceNGbNy4MeH17du3Y/fu3fj3f//3mJF0\n2D333BP585IlS9De3i4bnO12p9JrTpnVaoLNNpC18482jy+AZTdW4dbGSTHTvz094nuPU7n/TrsT\nNpGABwBdvS60n+3CntYOtLbb0i4mkmp7yPeOXcInFk9JmBWYO92Cd450JLx/7nQLnA4P1i+fhs1r\n6/H/vt6KE+ftsA94Iklc65ZOBQA4Xd6EBK9PfWxaxn5eZ/7aPaZJYePt736qCvn+C/negfy4f7mH\nh7Qe6S9cuIDXX38dv/zlL6HXJ65xnj17Fj/84Q/xwgsvIBAIoKWlBbfddls634qiiK1tNsyoQHPT\nFFhKDRlJQkq2pr3r8AXZEpZKTLYaEwqbyOnpF090k4rv/mAQnXYnyox6WIsE3P/JGyWTtUa65SnT\nOQBERECawXn79u3o7e3Fgw8+GHntpZdewiuvvIJFixahsbER1dXVuPPOO6FWq7F69Wo0NDRk7KIL\nldja5p7WS9jTegkVGUpC0us0aKitxJ4WkRHpDEtMyc50LJ9TjXtvq8P2PWfw3gdXIuvVekENny8o\nOqLWC5qEIOfxBXBUYkvZ3qOX8afWy7CU6rF83mSsWzpVdlvTSLY8ZToHgIgIAFShkFib+dGXzemH\nfJjeAOS34nh8AXz1xf1Jp5Kbm2oSkpCU3n94ZN5yshM9A97I9HNFqR6zpprhDwZx4KPOpOdRqYZK\nWsa
zmPR4/sElkXvz+AKw9boQCASx52iH5L5lg6DB9x75WORzgWAQL/3+I+xXcC2A+M8kk4ZnNBL3\nP491tna+/N3PlkK+/0K+dyA/7j/j09qUWUq24iTrIxyWaiGSaPEj8/Ao1uHy4r3jV2Q2LsWqNhfj\nck9iDsGCemvMdel1GtRYjdi6q122oIj32kNLlbkYgWAQX3/lcErT4iP5mSjBimBElGnchJkDwkFR\nrp9weG0zmehey6mQqzzm8Q1F6WRTLAZBA4OgwZUeZ+TP4f7FqxZMxqrGyfD4YjOnlbSQjF673fp2\ne0qBGUj/Z5Kq8PQ4AzMRjRSD8xhLVo4zHMzCa5vJpJuEpHRkLkavU+OmG6vg9gbg9gYQAiJ/Xjqn\nGg0zLGg73YWvvngAX31xP7buakcgGFT8fcNrtx5fAK2nulK+PiZmEVG+4bT2GEulHGd0Na/ufrfo\nZ9JNQpLLOk5m2Zxqyb2+Le22mCIl8QU65L5vdJ9kYOhn1euQ7xstholZRJRvOHIeY3LT1VK1nZ97\n4CY8/8BNWLVgMipKDVCrhqaOm5tq0m5LqHRkDgwFTVXU92xumiL5gCFVPSw8KyD3fVfOn4R7P14f\nWXcv0mtRbhQUXWNYkXhq+nIAABFXSURBVF6D9Sump/QZIqKxxpHzGEtnK45ep8HEihLc+/F6eFZl\nrtlCfJ1tQacRDa4r50/C2sVTI9/T4wukPOruiZoVSNbtKTphLtWRs8cbgMPplW2iQUSUa/g/Vg5I\ntRVhtEy2JYzOOu7pd+Otw+dx4MOrkc5PBkGDJbOr0Nw0JeFhYNZUs2wt63gqADsPnsfmNXVJs53j\ns8hTUVlexPVmIso7DM45YKRbcTLdqlCv02BPawfebY3d3uT2BrD/w068e63Ax/yZlQgBOHaqC939\nHhgENQAVvL6A5Kg7LBgC9rRegkajjuxBFnvQkEuY0+vUKDFo0evwSn6/JXMmcr2ZiPIOg3MOSXUU\nnK1WhXIBMRwAu/s9CXWtwyPsJbMnoP28XVG3qmR7kOUS5nz+IL5413wIWjWMxQLe2Hs2YfbhvnWz\nJWuOExHlKgbnPJatVoUj2VYFACfP9cKucG04PiM9nnztaj2s5UWRwC42+yDVpYuIKJfxf648pXR/\ntNznO+1O0fcpLXgixZ5CwY9ke5D1Og2KDYldzwCg2KBLGHGzEAgRjQccOecpudFtd78bPf1uTKwo\nSTgWPxVebtRjfl0lNjfPjEyFy2WQK5FKS8hke5A9vgAGXeKj8EGXL7Idi4hoPOHIOU8lG93uOnxB\n9PX4UqF2hwd7Wjrw9VcOR6p2AUMZ5M1NNZF91AZBeQBUEpjVKmDVgslJM9L7HB7YB8SDc6/DMypl\nOYmIRhtHzjkqWQa2XGtHAGg705MwqpSbCr/Q6cDWt9tx79pZAGIzyMOdo/7Udhltp7sjCVfzZ1Zc\ny9Yefq2htgLHTtnQIxFQw0IhYO2iKUkT19gvmYgKEYNzjkklA7t5YY1kcBZLtEqW6PXe8SvYcEtt\npGBHIBjEr989E3MtDTMq0Nw0BZZSQyTwb7wl9kFCo1YlnRK3lIoH1vBDSZFeC5fHjzKjnv2Siajg\nMDjnmFQysC2lBlSkMKosM+pRbtRLJmx5fUG89nY77v/kjZLXEr83GUjcApZODXC5XtLzZ1Zi9cLJ\nMSN0pUVaiIjyEYNzDpGbdj58ohPrlk2DqXi4tnSqpT/1Og3m10lPhQPAifP2SAa3XDa43N7k+Epj\nuw5fQNuZHtnAKtVLOryfurmpBs89cBP7JRNRQWBwziFy0869Di+eefkQFs6KneJOtfTn5uaZOPFX\nOy73OEWP2weGk6yUdsuSEqkBvnaW7Bq6kp7O4QeCTJUqJSLKZQzOOSRZ20a7I3GKO9XSnxq1Gl/5\nu4V47P++B48vmHA8PB0eCIagF9SRql9i70mlbKhc9TMlRU+UPhAQE
Y0HDM45ROn+4tb2LqxbNi2S\nMKXXaVIq/Vms12HFvEmS0+EAsPXtdtHADADzZlYkJIqNpGyokl7SzMwmokLC4JxjwtPRh090SrZH\n7O534+mXD6LP4U07MIpNh8+bWYFQKISvvrhfMlAaBA2CwRB2tw6vW4+0bKiShxJmZhNRIWFwzjHh\naep1y6bhmZcPSWZWhwN3KoExfho6fjr81++eSTpq93gDOHaqW/RYskQxOeGHhZaTNvQMeGKytcMP\nH0REhYLBOUeZigUsnKW8hGZ8YIwOxIFAEFt3tYtOQ4enw5UkZQFAmVFAr8QDw0jWhePXzqP3OXPE\nTESFhsE5h8VPPZeVSO9RDgfGijJDQhGTMqMeZy/1R94rNtpW2omqcWYl2s50Z61iV/TaefS2MSKi\nQsLgnMPERpNff+WQbGAUKxwitX4cPdpOlpRlMemxoP7a2rbmNCt2ERFlEYNzHogeTcoVHQGkC4eI\niZ6GlkvKWj6nGlvW1kcCb6p7q4mIKDUMznkkEAzCHwhCr1XD4x/a5mQQNFg2txqbVteiu8+taGo6\nLH4aWi7oRmeCp7q3moiIUsPgnMOik7q0GhW+/sphXOh0xLzH7Q3A6fLjctegov3C0eKnoVMNuqns\nrSYiIuXSCs6/+c1v8P3vfx9Tp04FACxbtgyf//znY97z5ptv4uc//znUajXuuusubNy4ceRXWyDE\nOlMZ9Fp02AZF37//o6vY/9FVGAQ1KsuKACQG5ylVRjjdfkXT0Ay6RERjK+2R8+23344nn3xS9JjT\n6cQPf/hD7NixAzqdDnfeeSfWrFmD8vLytC+0kIgldYkF3HhubxAXbYMJgXj5vElYt3Qq/IEQp6GJ\niPJAVqa1jx07hrlz58JkMgEAFixYgJaWFqxevTob3y4vKK1D7fEF0HKyc0Tfa9Dlw9OfXRTZJ1wz\nqRw22wA0anBETESUB9IOzgcPHsT9998Pv9+PJ598EjfeeGPkWFdXFywWS+Rri8UCm00+i9hsLoZW\nm73RnNVqysp53V4/7P0emEv1MAiJP85AIIiXf/ch9h+/DFuvC9byIiyZMxH3rZsNjUad8N4f/Ooo\negbEy3YqZR/woKjEgOnXlURey9b954NCvneA91/I91/I9w7k9/0nDc7bt2/H9u3bY177m7/5Gzzy\nyCO45ZZb0NraiieffBK/+93vJM8RCoWSXojdLt7CMBOsVhNstoGMnlNsXVisxvXWXe0xU9Sddhfe\n3HsWTpc3odzm1l3t2K2wIpgcs0mPgNcXueds3H++KOR7B3j/hXz/hXzvQH7cv9zDQ9LgvHHjRtlk\nrsbGRvT09CAQCECjGRr5VlVVoaurK/Kezs5OzJ8/P5Vrznli68LxVbfkSmKKldtMtkc5vJbc0++G\nTqeGV6TlIwAsqLdyTZmIKI+l3t8PwIsvvojf//73AID29nZYLJZIYAaAefPm4YMPPkB/fz8GBwfR\n0tKCpqamzFxxDkgWdD2+AAD5kpjhAiBhycpnLptTja99pgnPPXATvvW5Jfjuwx/DrQsnwyAM/9wN\nggarF05mMRAiojyX1przunXr8KUvfQmvv/46/H4/nn/+eQDAT37yEyxatAiNjY147LHHcP/990Ol\nUuGhhx6KJIeNB0qCbpW5WHbfcXwBELn3VpTqce/aemjU6pikrv+1ph533lILW68LCIVgvVbpi4iI\n8ltawbm6uhqvvvpqwusPPvhg5M+33XYbbrvttvSvLIcpDbpyJTHjC4DIv1d6mlqv06DGakz3VoiI\nKAexQlgaUgm6qdShZs1qIiICGJzTpjSQplISkzWriYgIYHBOWzbrULN8JhFRYUsrW5uGhQNpvoxw\nPb4AOu3OSEY5ERHlHo6cC4TSoilERDT2GJwLhJKiKURElBvG7ZCJ07fD3F6/oqIpRESUG8bdyFls\n+nb5vMlYt3Rqzk/fKu1clSp7v
7KiKURElBvGXXAWm76VajSRK7K9HmwuVV6pjIiIxl5uDyVTpLTm\nda4JP1B093sQwvB68LbdpzNyfoOgRWOdVfRYfNEUIiIae+MqOKfSaCJXjNYDxabVtWhuqkFFqQFq\nFVBRakBzUw2rjxER5aBxNa2dSqOJXKG0icZIsfoYEVH+GFcj53DNazG5On0bfqAQk40HinwrmkJE\nVIjGVXAGxKdv71gxPWenb/PxgYKIiLJrXE1rA+LTtzWTymGzDYz1pUliNyoiIoo27oJzWD41j8jm\nerDHF8DlrkEEfAGOwomI8sS4Dc75KJMPFDF7pwc8sJhYS5uIKF8wOI9TrKVNRJS/OIQah/K1GAsR\nEQ1hcB6H8rEYCxERDWNwHodGe+80ERFlFoPzOMS900RE+Y0JYRmWrbaPqeLeaSKi/MXgnCHZbvuY\nqui90xpBh4DXxxEzEVGe4LR2hmS77WO69DoNJlaWMDATEeURBucM4NYlIiLKJAbnDODWJSIiyqS0\n1px/9KMfYd++fQCAYDCIrq4u7Ny5M3L84sWLWLduHebMmQMAMJvN+MEPfpCBy81N+dhHmoiIclda\nwfnzn/88Pv/5zwMA/vM//xPd3d0J77n++uvx6quvjuzq8kR461J0ucwwbl0iIqJUjShb2+/347XX\nXsMvfvGLTF1P3uLWJSIiypQRBee33noLH/vYx2AwGBKOdXV14Qtf+AI6OzuxefNm3HHHHSP5Vjkv\nm20fiYiosKhCoVBI7g3bt2/H9u3bY1575JFHsGLFCtx///149tlnUVNTE3Pc4XBg586duOOOOzAw\nMICNGzfitddeQ1VVleT38fsD0GoZzIiIiJIGZylOpxMbN27Ef/3XfyV976OPPop77rkHS5YskXyP\nzTaQzmUoYrWasnr+XFfI91/I9w7w/gv5/gv53oH8uH+r1SR5LO2tVCdOnMD06dNFj+3fvx/f+ta3\nAAwF8RMnTuD6669P91sREREVlLSDs81mg8ViiXnt+eefx4ULF9DU1IS+vj5s2rQJn/70p/Hggw9i\nwoQJI75YIiKiQpD2tHamcVo7ewr5/gv53gHefyHffyHfO5Af95+VaW0iIiLKDgZnIiKiHMPgTERE\nlGMYnImIiHJMziSEERER0RCOnImIiHIMgzMREVGOYXAmIiLKMQzOREREOYbBmYiIKMcwOBMREeWY\nggjO3d3d+Pu//3vce++9uPvuu3Hs2LGxvqRR4/f78eSTT+Kee+7BXXfdhcOHD4/1JY26gwcPYunS\npdizZ89YX8qo+uY3v4lNmzbh7rvvRltb21hfzqhrb29Hc3MzfvnLX471pYy6b3/729i0aRM2bNiA\nt956a6wvZ1S5XC48+uij2LJlCzZu3Ji3/+61Y30Bo+HNN9/E3/7t32LdunU4ePAgvv/97+Pll18e\n68saFb/97W9RVFSE1157DadOncKXv/xl7NixY6wva9ScP38eP/vZz7BgwYKxvpRRdfDgQZw7dw7b\ntm3DmTNn8NRTT2Hbtm1jfVmjxul04hvf+AaWLl061pcy6vbv349Tp05h27ZtsNvt+NSnPoWPf/zj\nY31Zo2bPnj3/f3v3D5JaFIAB/BNvRtHfK9ewLVqKIlqaoqJoimgTWguChhqL4g7NRrQooZiDQ2Bo\nBEFDEVE0BOGoREtLiFEXScqSQHhDcHnCe5EP3j3q+X7TuWf6DlzOxz2IB/39/VhYWEA6ncb8/DzG\nx8dFxyqbFOU8NzdnjjOZjFTXV87MzGB6ehoAoKoqXl5eBCeylqZp8Pv90HVddBRLXV9fY3JyEgDQ\n3d2NXC6Ht7c3NDU1CU5mDYfDgVAohFAoJDqK5YaGhjAwMAAAaGlpwcfHB4rFIux2u+Bk1piamjLH\n1bzfS1HOwNf904uLi8jn84hEIqLjWKaurs4cRyIRs6hl0dDQIDqCEIZhoK+vz3xWVRXPz8/SlLO
i\nKFAUaba3Ena7HY2NjQCAeDyO0dFRaYr5d7Ozs3h8fEQgEBAd5Z/U3Nsbi8UQi8VK5paXlzEyMoKD\ngwNcXl5ifX29Jo+1v1v73t4eUqlU1b6oP/Hd+mXHf+mVz9nZGeLxeE3udT8RjUZxe3uLlZUVHB0d\nwWaziY5UlporZ4/HA4/HUzJ3c3ODXC6H1tZWjI2NYXV1VVC6/+tPawe+Suv8/Bw7OzslX9K15m/r\nl5HL5YJhGObz09MTNE0TmIisdHV1hUAggN3dXTQ3N4uOY6lkMgmn0wm3243e3l4Ui0Vks1k4nU7R\n0coixa+1T09PcXh4CAC4u7uD2+0WnMg6Dw8PiEaj8Pv9qK+vFx2HLDI8PIyTkxMAQCqVgsvlkuZI\nW3avr6/Y3NxEMBhEW1ub6DiWSyQS5mmBYRh4f39He3u74FTlk+JWqmw2i7W1NeTzeXx+fkLXdQwO\nDoqOZYnt7W0cHx+js7PTnAuHw3A4HAJTWefi4gLhcBj39/dQVRWapklzzLe1tYVEIgGbzYaNjQ30\n9PSIjmSZZDIJr9eLdDoNRVHQ0dEBn88nRVnt7+/D5/Ohq6vLnPN6vSV7QC0rFArQdR2ZTAaFQgFL\nS0uYmJgQHatsUpQzERFRNZHiWJuIiKiasJyJiIgqDMuZiIiowrCciYiIKgzLmYiIqMKwnImIiCoM\ny5mIiKjCsJyJiIgqzC8iivHPF8qqogAAAABJRU5ErkJggg==\n", - "text/plain": [ - "\u003cmatplotlib.figure.Figure at 0x7f7a18dfb8d0\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - } - ], - "source": [ - "# Plot the Data (Optional)\n", + "assert f(pi/2).numpy() == 1.0\n", "\n", - "import matplotlib.pyplot as plt\n", "\n", - "plt.scatter(inputs, labels)\n", - "plt.show()" + "# grad_f will return a list of derivatives of f\n", + "# with respect to its arguments. Since f() has a single argument,\n", + "# grad_f will return a list with a single element.\n", + "grad_f = tfe.gradients_function(f)\n", + "assert tf.abs(grad_f(pi/2)[0]).numpy() \u003c 1e-7" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "JaFHyAG9nDET" + "id": "v9fPs8RyopCf" }, "source": [ - "## Step 2: Define our TensorFlow variables\n", + "### Higher-order gradients\n", "\n", - "We'll use Keras's object-oriented [`Dense`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense) layer to create our variables. In this case, we'll create a `Dense` layer with a single weight and bias." 
+ "The same API can be used to differentiate as many times as you like:\n" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "base_uri": "https://localhost:8080/", - "height": 34 + "height": 276 }, "colab_type": "code", "executionInfo": { - "elapsed": 332, + "elapsed": 730, "status": "ok", - "timestamp": 1525154229931, + "timestamp": 1527005655565, "user": { "displayName": "", "photoUrl": "", @@ -190,54 +120,61 @@ }, "user_tz": 420 }, - "id": "z9r-ZeyrXu3A", - "outputId": "e19a698e-5892-4fcd-80d3-1394605ee72c" + "id": "3D0ZvnGYo0rW", + "outputId": "e23f8cc6-6813-4944-f20f-825b8a03c2ff" }, "outputs": [ { "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAEDCAYAAAAhsS8XAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsnXd0HNX5sJ/ZXrTq3ZLV3IvcDdgGGwOm2WCbHhJa6C2B\nUBISQioBfoQPkjhACA4QCIQSDITQbGMbsHHvVbZ6s7q0vc18f4xmJVltJa0q+5zDOXhn9s7dqzvv\nfe/briBJkkSYMGHChBkxqAa7A2HChAkTJrSEBXuYMGHCjDDCgj1MmDBhRhhhwR4mTJgwI4ywYA8T\nJkyYEUZYsIcJEybMCCNkgl0URVasWMHtt98eqibDhAkTJkwvCJlgf+2118jJyQlVc2HChAkTppeE\nRLBXVlayceNGrrjiilA0FyZMmDBh+kBIBPvjjz/OQw89hCAIoWguTJgwYcL0gT4L9g0bNhAfH8/E\niRMJVycIEyZMmMFH6GutmGeeeYYPP/wQtVqN2+3Gbrdz3nnn8dRTT3X6HUmSwtp9CKittvH8UxsQ\nxZY/4aXXTGfa7PRB7NXAU1dj5y9PrIfmYUgeFcnya2aQmBI5uB0bYE5WNPHS/9uE6JcHYukVucw8\nPWOQezXw7NhcyCfvH0Bqfi+uumkO4ycnD3KvBpY+C/bWbNu2jdWrV/PCCy90e291tTVUj+03EhIs\nQ7qfWzfls2tzMTNPH01UrJEv/3eU5LRIVnx/5mB3rUP6azw3fnaMQ7vLOX1RNrVVNvIOVZGeFcPS\nq6YNmT6GmlP7KYoi/3ltF9WVNhacO4btXxfi9fi5+Mpc0jJjhkw/+5t9O0r5Zu1xDEYtpy/KZuOn\nR4mOM3HlTbNRqTo3UAynv3swhOPYhymSJJF3sAqtTs35l05mQm4K6VkxVJY2UVdtH+zuDRgOu4ej\n+yqIjDYwbW4a514yiYTkCMqKGnC7vIPdvQFjz9YSqittjJuSxNTZaVywcgoAX3xwCL9PHOTeDRyH\ndpej0ai47PqZTJyWwoTcFOprHBzdf3KwuzaghFSwz507NyhtPUzfOVnehLXRRdbYeLQ6DQATp6UC\ncGhv+WB2bUA5sLMMv19i2pz0gEaWNS4BUZQoOlE3yL0bGDxuHzu+LsRk1jH/nDEApI6OZtL0VFxO\nLyfLmwa5hwNDU4OT+loHozJiiIw2AjB7QSYajYrtXxXg9foHuYcDR1hjH6bkHawCYOzkxMBnmWPj\nMJq1HDtwEt93
YBJ7PT4O7CrDYNQwPrfFhpo1Lh6AgmPVg9W1AaWyrBG/X2JCbjIGozbweXqWbIIp\nLawfrK4NKMX58kI+Oic28FmERc/UOWnYbR7yDn53tPawYB+GiKLI8SNVGEzaNvZTtVrFhKkpuF0+\n8o+OfKGWd7gKt8vHlJmj0GrVgc9j4kxExRopzq/7Tixw5cWNAKSkR7f5PHV0NIIApUXfEcHevEMb\nnR3b5vPxU5IAqChpHPA+DRZhwT4MKS2sx+XwMmZCYjuH0MRpKQAc3lsxGF0bUJQXNWdiYpvPBUEg\ne1w8Pq9IyXdAW60oaUAQ5Gig1uj0GhJTIqkqb8Lj9g1S7wYGn89PWXE90XGmgBlGITrWhN6gobIs\nLNjDDGE6MsMoRMUYSUiOoLKsacQ7zaoqrGh1amLiTO2uZY1LAKDgWM1Ad2tA8Xr9VFVYSUi2oNNr\n2l1Py4xBkqC8pGEQejdwVJQ04vOKZJyirYO80CeNiqSpwYXD5h6E3g08YcE+zJAkiZKCOswWHUmp\nHcdpJyRbEEWJupqRGx3jdvloqHWQmGLpMCciMcWCOUJH0fEaRHHkLnBV5U2IokRKelSH10dlyOaZ\nkW5nD5hhcuI6vJ48Sh6fyrLvhiM5LNiHGQ67B6fDS2JyZKdJXgnJcqxrdeXQj8vtLcpv6ywJSRAE\nMsfF43L6RvTLXF4sa+Kn2tcVkkdFodGoKCvqu8b+zjtv8f3vX8Fvf/ton9sKNUX5tWi0KlLSOl7g\nFDPVSJ4LrWm/dwszpKk5aQMgLimi03u+C4JdCeFLTOk8YSMlLYqDu8qpOWkjtRPBN9wpb/YzpHai\nsas1KlLSoygpqMdhc2OK0Pf6WWvWvMsf//hnkpNTet1Gf9BY76Sxzknm2DjUmo51VXlnBye/I3b2\nsMY+zKitkgV7fGLngj02wYxKLVBdaRuobg04VRWyYO/MHAUQlyCPkTJmIw2/X+RkeRNxCWb0Bm2n\n943KaA577IPW/vTTf6C8vIyHH76ft99+s9ft9AeKUzQto/MMW61OQ1xiBFWV1hHve4Kwxj7sUDT2\n+C40drVaRVyCmdpqG36/iFo9stZvSZKoKrditugwWzrXQKNijahUwojLxH17/XF25VXj8fhx+Hzo\nGh1s/+vmTu8X/SJ2RA59egTDxhMd3jNnQiJXLh7TaRsPPPAztm79lj//+UUiI4dWDZ76Zl9SXBfK\nDshmqZqTNqpPWgM295HKyHrjvwPUVNnQGzRERHa9pU5ItiD6pREn1ADsVjcOu6fbIl9qtYqYeBN1\nNfYRWXlU0Ty7W7hVzddFf181VYlApbUhhDLHYxPMXd6XnCbPl5PfATt7WGMfRng9PhrrnM2JJ11X\nx5Tt7BVUn7QGbO4jhZPliuO0+98VlxBBbZWdpgYnUTHtwyKHI1cuHsNdV83g1b9upuhELdf/cG63\ntvN/v7ydpgYnN99xxoirrFpX48Bk1rXJuu2IlsiYRqYxsiughjX2YURts2bSlRlGocWBOvLsy4p9\nPZiyvLGJshZXWzXydi71tXYMJm1QDtHYeBM+r4itaWTFcXs9PqyNLmLiu1+0IyL1mCN0VJY2jcgd\nXGvCgn0YEbCvd2NLBIiNN6NSCdSMwMiYqoqeaeww8hyoPq9fFmixwe1CouPkBa6+ti8L3NDT9Otq\nHED3ZhiQQ2ATUyNx2D3YbZ7+7tqgEhbsw4hgHKcKao2K2AQztVWyA3WkIIoS1ZVWYuJNHWZankpc\n8wtfO8J8DbLfAKI7yLrtiNhmjba+WRD2hnfe+YDIyKHldAzY1+O7F+xAIEu5sa734zAcCAv2YURt\nlQ2VWgj6ZU5ItuD3S4GogZFAU4MTr8dPQlJwfgNThA6DUTPinMi11fIi31E5hY5Q5kx97cgSaMrc\nDkZjB7luDEBDnbPf+jQUCAv2YYIoitRW24mNNwcdvjgS7eyN9fILGR1r7OZOGU
EQiE2IoLFeXhBG\nCjXNpqXoYE0xMSYEoa+mmKGHUjYjJi44wR7VPG/CGnuYIUFDnRO/TwzKDKOg3DuS7MuNzZpWVJAC\nDVrMMSOpdk5AsAepsas1KiJjjNTXOEaU47Cuxo7ZokdvCC7AL6yxhxlS9MRxqqBotU0NI2cSN9bL\nmlZUTHAaO7Qkrijmi5GAYpazRBmC/k5snBm3y4fTMTKODHS7vNitnqDNMAAGoxaDUUNDWGMPMxRQ\nbMTdZde1Rm/QojdoaGxw9Ve3BhzFFNMTwa68+HUjJORRkiRqquxExciZtcESHXCgjoxxCETEBBHq\n2JqoWBNNDc4RFVRwKn0W7B6PhyuuuILly5ezbNky/vKXv4SiX2FOQdG6eyLQlPubGpyI4sjYfjfW\nOzGatEFFxCgoERMjJTLGYfPgcfuCdpwqxI4wB2rAcRpkRIxCdIwRSQJr48hReE6lz4Jdp9Px2muv\nsWbNGtasWcOmTZvYt29fKPoWphVNDU7UGhWmCF2PvhcZbUT0S9itwz8xxe8XsTa6Ag6wYNHq1ETF\nGKkbIaYYRTAHa19XiGkWgL0NeWxdtvebb77ijTdeDfq7lZUVfPHFp0Hd+/jjv2bjxvXd3te6lMCa\nNe/x2Wf/C6r9qICdXR6HTz75L9XVLUdJPvnk7ykqKgyqraFKSEoKGI3yi+bxePD5RvYRXINFU4OL\nyChDj9PBFQ2/sd7ZI3vsUMTa6EKS6FVpgMgYIyX5TjxuX4+0/aGIIpCCTU5SUByHvY2MObVs7/z5\nZ7a7x+/3o1ar231eXl7GF198xnnnXdCrZ3eE4gyPjDawfPllQX9PGQfFEf+//33EzJlTSUrKAODh\nh38esj4OFiGZ4aIosnLlSoqLi7n22mvJzc0NRbNhmnG7vLhdvnZnWgZDZLQszGVTTudlTYcDgYiY\nHpqjACKjlHFw9SiyaCjS0EuNXatTY4nU98oU07ps78UXX4LFYuHIkUPcd99DPP74r7FYIsnLO8r4\n8ROZP/9MnnvuaQRBQKvV8OyzL/Dii6soKirkppuu5YILlnLllde0af+ZZ55k9+6dpKSktonaOXr0\nCH/+8zO4XC6ioqL5+c8fIzY2jnvuuQ3RFUt1XSGxH5Zht9sxmUycccYCfve7x3jpJXk3UVlZwcMP\n38+rr77JK6/8nW+++QqHw4lWSmTS9HvYsGEdR44c5sEHH0Sj0fL886t54IF7ufvu+zh8+ADl5eXc\neee9gKzZHz16hB//+AE+//wT3nnnLfx+H5MmTeEnP/npkKrBExLBrlKpWLNmDTabjTvvvJPjx48z\nZkznJUDD9IymZufnqYf0BoMiBEdCZExDLyJiFJQFztroHPaC/Vv315ROK+RP+ZsRCnomTJxjPfh8\nInnfbGgjiGYkTmXlmKWdfu/Usr2ffPLfNt8vLS3mT396AYCHH76Pn/zkp0yZkktEhIamJg+33343\nb731Ok8++f/atb1x45eUlpbwz3++TU1NDd///hUsXXopPp+PZ599iieeeIaoqGjWrfuCF19cxc9+\n9kskScLpsHPdlT9l6VXTWL36bwBkZGTi9/uoqCgnJSWVdes+55xzzgPgssuu4oYbbsbn9XPj9+9k\n566t3P/Idbz33tv88pe/ICGhbWGwRYvO5fbbbwwI9nXrPuf6639IUVEh69Z9zgsvrEatVvPHPz7J\n559/wvnnX9Sjv0V/EtI9aUREBHPnzuWrr77qVrAnJAyPioNDoZ/VzdUMU9OiO+1PZ58b9HLFO5fD\nNyR+S1/64HHKCUaZ2fE9bidttLxb8fukbr87FMapK9xuH6oIAU0v6uyrNSp8PhEkUKtbBLPJqOv2\nd6tUEBdnJjragsViwNj8HYNBy8KFSwPfP/30uTz//HMsW7aMJUuWkJSURHS0CZ1O0+Ezjh07wIoV\nl5KQYCEhwcK8eWcQGWnEZquhoCCfBx+8F0
mSEEWRxMREEhIsCAhkpE4nMTmShAQLZrMes9lAQoKF\npUsvZuvWTdxyyy1s2rSeZ599loQEC7t2bebll1/G6XRSVV9FWXkaCQkWtFo1ktQyL7RaNTExJsaO\nTSczM4OKigJGjx5NeXkpixcv4I033uD48WPccceNSJKE2+0mLS15SM2bPgv2uro6tFotFosFl8vF\nli1buPXWW7v9XnX10C9OlZBgGRL9LC2WDyJWaYQO+9NVPyVJQqNVUV1pHfTf0tfxPFkhn5QjIva4\nHalZhlWUNnb53aHyN+8Mr9dPbN5YZo45gwvPn9Lj7x/aU87GT49x9sUTmDA1uc217n63KErU1trw\netVYrS6cTg/V1VZcLi8+X8vcXLHiGqZNm8uWLV9z5ZVX8swzq2hocODx+Dp8htPpwWZzB6653V6a\nmpzU1dnIysrm+edXt+uny+VFY9Gh0amorrZit7uRJDXV1VZOO+0sHn30p8yaNQ+/X8JojKGsrJZf\n/erXrF79OvHxCTx8/2+wNjgpL6vH6/W3+f1er5/6egfV1VYWLDibd99dQ0ZGJvPnL6S62orV6mTJ\nkou47ba7ejR+oSDYxaPPUTHV1dVcd911XHrppVxxxRUsWLCAhQsX9rXZMK1QzCi9McUIgkBktJHG\nBuewzzhsqHNiMut65fxsbYoZziip8PGJPQvxU1Ac6LZ+DPUrKyslOzuHa6+9nilTplBcXIjJZMZu\n79hpO23aTNau/RxRFKmpqWHXrp0AjB6dSX19AwcO7AfA5/NRUJAPtBwy0tE7MWpUGmq1ilde+TuL\nF8tmGI/HgyBAZGQUDoeD44W7geY5ZTJhs3UcMbVw4WK++mpDG5POrFlz2bBhHfX1ssLV1NREZWVl\nr8aqv+izxj5+/Hjef//9UPQlTCcoSTmW6N5FtcihfnacDi8mc8/CJYcKfr+IrcnV6yPN9AY59r1p\nmMcuK6nwSjninqIIdmtTb8YhOHv+O++8ya5dO1Cr1YwfP47TT58PgFqt4cYbv8eFFy5r4zxduPBs\ndu3azvXXX016egYzZswCQKPR8LvfPcmzz/4fNpsNUfRz5ZXXkJWVjd8vtfk9p7J48RKef/5P3HLL\nnYBsJl62bAXXXXcVKSmpZGeNw14vv1sXXbSMxx57DK1Wx/PPr27jO7BYLGRmZlNcXMiECZMAyMzM\n4pZb7uT+++9CFCW0Wi333/8QycnJHfZlMBCkQVLjhvJ2V2GobMtff/5b/H6R6++e1+H17vq5ef0J\n9m4rYcX3Z5CcNnhlV/synvW1dt56aTsTpiZz9sUTetXGO//YQUOtg5t/cmanEQxD5W/eGbu/Lebb\nDflcddOcwCEiPcHn8/PS018xKiOaS66Z3g89bEt/jeen/zlAwbEarr9nXq+UleL8Wj5+ez9zFmQy\ne0HmkP+7KwyYKSZM/6JoqpG91NahVSz7MI6MCZQS6GFyUmssUQZ8PhGnffgesqBkS0b38pg/jUaN\nyawb9lmX1kYXGo0Ko6nr4/A6I1AMrH5kZOGeSliwD3FsTW4kCSKjei/QomLkRUERjsORvsSwKyj2\n2OFsjrE1m1D6Mg4RUfrmeTV8fS7WRheWXiTsKUREGlCpBJrqh+9c6IqwYB/iBBynoRBoI0Fj78OB\n1C3JWsP3ZbY2udHp1d0e3NwVkVEGRFEatsfDuV0+3C5fnzKpVSoBc4QOm3X4zoWuCAv2IU5LclLv\nJ7GinQxrjT0g2Hs/DgHH4TBd4CRJwtroIiKyb6UhlO/3Z2RMf6LsWvpaIiMi0oDd6hmRVR7Dgn2I\n05dQRwWVSsASbRjW205rkwuDSYtW1/tAroDGPkwFmsftw+vxY4nU96kdRSAO13FQ+t3bKDGFiCh5\nHEdCgbxTCQv2IU6LYO/bJI6KNuJyyjVnhhuSJGFvchNhCZFAG6amGGujLIAi+qipBmLZexXyOPhY\nlV1sCD
R2kP1YI42wYB/iNDXI3v++xp8PZzu72+XD5xOJ6KOmqtGoMUcM34gQJfbc0kdTjPL94TYO\nu3fv5KGH7gv0uzNTzD333MbRo0e6bU/Z+diaXPzpT39i587tverX22+/idvdsjg89NCPsdsHt0R0\nWLAPYSRJoqnBiSW6995/BWXbaRuG207lRY6w9L3ssCXagK3JNSztqoqG3dcFztI8F4abYAcQBLoV\n7MGiaOyNDU7uvfdeZs2a06t23nnnTdzulrF86qlnMZsHt9Dc8C5MPcJxu3x43H5S0ntvX1dQzBjD\n0Z6oLEZ9FWggh41WljZht7r75LcYDBRTTF8FmlanwWDU9Eiwu1wufvnLn1JdXYUoilx//c0sXnxu\np2V1y8pK+b//exybrQlJEvjtb58gNXUUq1Y9x9atmxEEFddddxPnnHMeu3fvZPXqvxEVFU1BwQkm\nTJjIo4/+FoBvv93Mn//8DNHRMYwdOx6ApkYnGq0qEBnkdrt5/PFfU1RUSEZGBh5PS7TP9u3f8vLL\nf8Pr9TJqVBqPPPIYBoOBK664hLMXXcAXm7/Er1vGV9veZNas09HrDfzvfx/xm9/8AZB3Cf/+9xs8\n8cQzPP30Exw9egi3282iRedw00238u67b1FTU80999xOdHQ0zz33PFdccQkvv/xP3njjNZKTU1ix\n4nIAVq/+G2azmauuupZ//euffPnlF3i9Ps46axE33dR9fa2eEBbsQxjlxeurLRFaBPtwtCfam0In\n2C2tQh6Hm2BXNHb/hv+yY9WePu065tg8iH6J/IffBsAyew4JV1zd6f1bt24mPj6Bp556FgCHw95l\nWd1f//oXXHfdjaxYsZTy8jpEUWTjxvWcOJHHa6/9m/r6Om6++TpmzJgJQF7eMV5//R3i4uK4444f\nsn//XsaPn8hTT/2eP//5RUaNSuOXv/wZ0D6Gfc2adzEajbzyyr84ceI4N910LQCNjQ28+upqnnvu\nr+j1Bt5441Xeeut1brjhZvk3R5pZMu8uRqfHcrSkVB6XOafx9NN/wO12odcbWLfuCxYvXgLAbbfd\nhcViQRRFfvSjO8jPP87ll1/Nv//9ZqCcsYzcr3PPXcJzz/0xINjXr1/LM8/8me3bv6W0tJiXXnoN\nSZJ4+OH72bt3D9OmhS4TOCzYhzC2EAo087DW2BUTRN8XuMCBG43D7+ARa5MLlUpAq1XTVxe4oBKQ\n/CKSJJs3uiM7ewyrVj3HCy/8hTPOWMC0adPJzz9Bfv4J7rvvruayuhLx8Qk4HA5qaqpZsEAuBqjV\nypr1vn17OPfc8wGIiYllxoxZHD58CJPJxKRJk4mPjwdgzJhxVFRUYDAYSU0dxahRaQAsWXIhH3zw\nH3kXm9ayKO/Zs5srmhelnJwxjBkzDoCDBw9QWJjPHXf8EEmS8Pl8TJkyLfC9JUvO57//ymuj7KjV\nak477Qy+/vorFi1azJYtX3PXXT8CYN26z/jwwzX4/X7q6mopKCggO3sMIDX/pyD//9ix42loaKC2\ntob6+noiIyNJTEzinXfeYvv2bdx007VyXXmni9LS4rBg/66gCGFzH6NBWrcxHCMhrMoCF4JxaHEi\nD79xsDW6MVv0JF55NQl33dKn2ibfrDvOvu2lrLxuJkmp3Z/MlZ4+mpdffp0tW77hxRf/wty5p3PW\nWYvIzs5pV1bX4ei4iuOpma6t/60IfwC1WoXf3/HS5fPKu5RTzVGtfVBKu5IkMWfO6Tz22O86bMto\nNBIRaWj3TixefB7/+c/bREZamDhxMkajkYqKct566w1efvmfmM0RPP74r/F4uleSzj77HL78ci21\ntbWcc86SQL9+8IMbuOSSFd1+v7eEnadDmIBtOQQCTa2WD8Iejs5TW5MbQQCzpe+VKZXdj32YmaT8\nPhGH3ROyc2t7GvJYU1ODXq9nyZILuOaa73Ps2NFOy+qaTGYSE5P46qsN
AHi9XtxuF9OmzWTdui8Q\nRZH6+nr27dvDpEmTO31mRkYmlZUVlJeXAbB27Wf4mmuntx6H6dNn8PnnnwCQn3+cEyfyAJg8eSr7\n9++lrEw2s7jdLkpKits8IyJSj8ftlw8faWbGjFkcO3aUDz9cEyjVa7fbMRqNmExm6upq+fbbzYH7\nuypJvHjxeaxb9zkbN67n7LPPAeC0007n448/xOl0No9tdaAEcKgIa+xDmFBq7CAvEDVVNiRJGlLn\nM3aHvcmFKUKHStV3PSSwcxlmC5xijuprcpKCEvIYbJJSfv5xVq16DpVKQKPR8sADP+uyrO4vfvFr\n/u//HueVV15CENT89rdPsHDh2Rw8uI8bbrgGQVBx5533EhMTS2FhQZtnKXNTp9Px4IOP8OCDPyI6\nOobc3OmcrKiT+99KsC9ffjmPP/5rbrjhe4wdO45Jk+QDSKKjo3nkkcf41a8ewePxIggCt9xyB+np\no1Hs4Ip5T1kwQD7qc968BXzyycf84he/BmDMmLGMHTueH/zgKlJTR5Gb22LSueSS5TzwwL3Exyfw\n3HPP07q8cVZWNg6Hg4SEJGJj4wCYM+d0iooKuf32GwEwmUw8+uhviYkJnWkwXLa3Cwa7lOcH/9pD\neXEDtz54FuoujkELtp99LXXaV3oznqIo8dLTm0hIsbDyBzND0o9X/vQNOr2G7912Wkj6OBCUFtbz\n0Vt7mTUvg7lnZfW5nzUnrbzzj51MmZnKmUvGhbCnbQn1eH79RR77d5Zx+Q2zSEju+1F0u7YUsXVj\nAVf/cC4xCb2vQzRQhMv2jgDsVjdGs7ZLod4ThmPIo8PuQRSlkJijFMwWPXbr8KpuGKr6KAqBujnD\nLJY9lKGvcjtKlNTwS9zrirBgH6JIkoTN2vc0+tZERA6/kMdQJeW0JsKix+cTh1V5hUCSVojGQT5R\nSh1wTA8X7FY3KrXQp+qWrVHGczifVdARYcE+RHG7fPh9Ysjs69CqNsYwKlVqDziQQ6OpwvAM/VQW\n41Bp7CDb2a2NrmG1c7Fb3Zgj9CHzEQV8DcO48mlH9FmwV1ZWct1113HRRRexbNkyXnvttVD06zuP\nLYQhfgrDWaCFUmMfjg5UpU5MSOdDpB6vx4/X4+/+5iGAKMqRQaGIjlIwRegQhJGnsfc5KkatVvOz\nn/2MiRMnYrfbWblyJfPnzycnJycU/fvOEuqIGBie2afWfjDFBBY42zAah0YXRpMWjVYdsjbNES0L\nvU4/9APkHHYvktTS71AghwHrh/VZBR3RZ409ISGBiRMnAmA2m8nJyaGqqqrPHfuuE8oYdgVThKzp\nDCfB3qKxh84EEXAiD5NxCPhbQjgGMPwWuP5QdkAOIW1qdCGKw8ck1R0htbGXlpZy5MgRcnNzQ9ls\nv2I/sA9nfv5gd6Md/TGJ1WpV83Fg7V9kT2UFjsOHQvasUKE4y3p7aHFHKFv51uMgSRKi14vo9SL5\nhpZT1enwIvqlkO5aAMzNC73d2lI0S3S7se3Zjd/WtuyszWbj/fffDfxbKaHbEU8++XuKigq7fX5X\nbbRGKcMbeCeC0NhffvnFoMvwRkQakEQJR/MC9/bbb+JyOrHu2IansmJIlOHtKSHbf9ntdu69914e\neeQRzGZzt/cHG4/Z3xT+8xU8dfWkLL2IjB9ci1rfdtIMVj+V1OnRmXHExoduPKNiTVSWNRIfFwGS\nSPlHH1P15QYchUUApF9zFaOvvrL3HQ9RPxUcNg9R0UYSE7tPew+WSItcVsDr8ZOQYEH0ejn8uz/Q\nsGcvxwFUKjK+/z3SLuu/lO+eUOlpBCA+IaLN+PV1bqamRcv/I8lt+Z1ODj3zJE2HDiOo1URNyyV1\n6UXEzJqJ293IRx/9h1tvlZNqoqNN6PWaDvvw9NNPtPm3co8oim2SzLpqozVarZqYGBP2etlhmjIq\nqsvviKLIT3/6QPcD0ExisoXjh6vQ
qNUkJFh49+03mJF/HOn4CZIvWMI//vFy0G0NFUIi2H0+H/fe\ney+XXnoGQVm+AAAgAElEQVQp5557blDfGSpJIEm33U3l6r9R8dHH1Gzbwagf/wRdQiIwuMkqNVXy\nc90eb7d96Ek/DUYNol+iuKgW19frqHn3bVCrMU+bjqesjJI3/43D4SFu2aV9/g196SfIafQ2q5vU\n9KiQ/x10ejX1tQ6qq61UvfUGDXv2ohuVhikhDmt+AUX/fANfXDLmyVNC+tzeUFoip5urNEJgHEIx\nN33N1SGrKps4WVpD2XPP4Dx2FOOEiYgOBw27dtOwZy8Zj/2Gx//2V4qLi1m27BJmzz6NM86YT0ND\nE7fddme7Urv33HMbd999H+PHT2DJkrO46qpr2bbtW+6++8fY7fY2ZXg9Hl+733FqGV673Ul9vYP6\nCh8V1cf42aOrEVRSuzK8F198Cdu3b2XlyivZunUz8+efGVQZ3sYGG3GWCZQUTeTdF/8fVSdP8ugX\nnxEdHcOq85exaNHZg16GVyHYxTwkgv2RRx5hzJgxXH/99aFobkAxZmeT8cvfUPOfd2hY+wU1775N\n6h13D3a3sFvdGIyhdZZBS9hgQ1k19o8+QB1hIePXv0MTFYW3tpbS/3uC2g/eR9DpiD3/wpA+u6co\ntt9Q25ahJUnJumsnDWu/QJeSyuhHHiUpLZ6SbXspfuL3VP79RTIe+y2a6OiQP78nKONgajZBbF5/\ngsK8GsQ+HhaimJSP7KvEsXsHWXlHiZg9h5RbbkdQq7Ht3kX5qj9R9cY/uf32uykszGf16jcAWUB2\nVGp36tRpbZ7hdDrJyRnDD394Gx6Ph6uvXtGuDO+pdFaGt7qqhgN5a3nxpb+RkBTdrgyvTqdn1aqX\nALnMMARXhvfEkZM89PC9HDt8mNOLi3lPq+WPj/6G1IVnN4dVDn4Z3p7SZxv7zp07+eijj/j2229Z\nvnw5K1asYNOmTaHo24Ch0ulIuOp76DOzsO3cgfuUQkEDTX8kJykoNvvyz9Yjud3EX34lmqgoALRx\ncaQ9+DDqqGhqP3i/nZ11oOmPUEeFCIset8tH+auvIOh0pNx+F6pmM5whK5uEK67Cb7VS8dILSOLg\nnrak2MAVm3ioUELBRb+Ir64Oc+40Um6+DUEtKxMRM2ZinjET57Gj2Hbvavd9pdSuIAiBUrunotFo\nWLhwMQBFRYXtyvB2xJ49uwPXWpfhPX7iCI22kzz007u48cbv8emnH3Py5MnA95SCXa1pXYbX7/ez\nZcvXnHmmXE543brPuOmm7/PL395Do/Ukx3ftQLTZEIwmLDNntYqVb1+G9/jxvEAZ3m3btgbK8N50\n07UUFxdRWjq4MqTPGvusWbM4fPhwKPoyqAiCQPzyFZQ9+wy1H35A6l33DFpfPG4fPm9ok5MUApl2\nReUk5owhct78Nte1cfHELDmfmnf+TeNXm4i98KKQ9yFY+iPrVEFxwDk9kPW9a9GPGtXmevQ55+E4\nchj7nt3Y9+4hYkZo6tT0BsWpp8yHeYtzuPSq6SExT73+/Ld4GxsZW7uDhB8/jqBpKxISr7qGwoMH\nqPv4w3YLXDCldnU6Xa+SiToqw+tyekhLnsA//vFih98xGjs+OKW7MrySX8Odt/4Ya1UtKrMZVSft\nwOCV4e0p4czTVpgmT8WQnYNt905cxUWD1g8lWsPcLwJN1vpcmggSr/0BQgcVE6POPAtBr6dh/dpB\njRCx2xRNNfTjYDLJWqkvbhSR889sd10QBOIvXQlA49eDuwPtL40dwKgRcUtajJNz0aWktruujU8g\n9qKlaB0ObLW1PW6/dVZrR2V4O6KzMrwW4yiq6gq6LMPbEd2V4XW6rZRXH8EraIg9/0LM5oghV4a3\np4QFeysEQSDuUnnVrf1wzaD1w94PMewK2qZqAPyJ6RhGZ3R4j9pkJmr+Anz1dR1uwQeK/opbBlDX\n
yMJFGJ/b4eIGoE9PR5+ZhX3/PnwNDSHvQ7DYbW40GlW/JBFprDVIggrDWZ0HPcScfwGRlkhydDqu\nu+5q/vrXP7W7p7WG3dn/63Q6Hnro5zz44I+4665bSOlgIQG5DK/D4eCGG77Hm2++zqRJU/B6fKgF\nI8vOv5lf/eoRrr/+Gm677SaKAwpY57sCpQzv1q1bmDdPXsRbl+F96onfkBSdgU+tJ3rxOYEyvD/6\n0R3t2u6sDO95553P7bffyPXXX82jjz6M0+notD8DQbhs7ylIkkTJE7/HdeI4s/72PFbVwJ+LeWhv\nORs/OcbZF09gwtTkbu/vSYTEybfe5D8FSSTGaLns9vaaqoLnZCWFP/8phpwxjP7ZL4Lue6j6CfDZ\n+wfJP1rNdXefEXKtffsfVrFDmMycuUnMXjyx0z42bFhP1euvEb/ycmIvWhrSPgTLK3/+Bp2ubZnh\nkETFNNTz6ZNvURI1kcuun0liSuchpZWvrKbp602kPfAwpgkTO73vVEIVWVZfY+etv29n4rQUFl04\nvs/ttWl7/Vo+3tSI0xTLzQ8uGtJnFYTL9vYSQRCInLcAgLqt2walD/Z+qBMDIIki9p3b0IsunGLX\n2p8uKRlz7jRcJ47jzD8R0n4Ei8Mun5xkNIXWBOEuL0dVIv8mp6/rqCPL3NMRtFoav/lqUIpl+f0i\nTrs3kDUcShq+XI/eKzvIFbNXZ0SedjoA1m1bQ96PYLDb+m/3Ztu1E53fgU8Uhk3dnO4IC/YOiJg+\nAwSB2i3fDsrzbf1kgnDmHcNXX49Rr8Jh93QrqKIXy9tza6tjwAYSu9WDyaxDpQqtBtX09Sb0Pnvg\nGV2hNpmImDUb78mTOPOOhbQfweC094+fQZIkmrZuwaCSfSjdFYYzjp+AOioK687tg+J3USKkQlkA\nDMBvteI8dpSICNkRPFzKK3RHWLB3gCYqCuOYsTQdPoKvqWnAn99iYw/tJLZukxeqiPhI/H4Jj7vr\nF9Q0YSIqgwH7/n0Drq1KkoTD7gm5pir5fDRt+Qa9SYtaLQRV4TFqwVkANH61MaR9CYYWB3Jox8FT\nUY6vpoao0SmAnOHbFYJKhWX2XES7HfuhgyHtSzD0lyPdtncPiCIxaXJSYncL/XAhLNg7IWLGTJAk\n7Ht2D/izbVY3Or0arS50zjLJ58O6cwfqqCgik2SnT3eTWNBoME2egre6Gm9l+xjl/sTjluvRm0L8\nIjvzjuG3Womae7qcpBSEhmYcPwFNfDz23bsGXFvtLweyfd9eAGInjmnznK6wzJVt/IqCMJD0V0CB\nbfdOABInZAItoaXDnbBg74SIGbOAlj/8QOKweUKumdgPHUS02bDMnoup+eVw2LufxObmTEJbsyAY\nKPorxM9+8IDcbm4uZoseh82Dv5sMTkEQME/JRXS5cBUMbME4RZMO9c7Fvm8vCALxM+SSCcEscIbs\nHLTxCdh270Z0D6wA7I8FTnS5cBw8gG5UGjHpSfJzutm5DBfCgr0TtAkJmLMycRw+hN85cLWa/c1H\ntoX6Rbbtkhcoy9zTWqr6BTGJzVOnyvfu3xfS/nSHsuiEWmN3HDyAoNFgHDs+ICQUO3ZXmCdPBhhw\nM0TAaRjCcfA77DiP52HIysIQG41OrwlqLgiCgGXuaUhu14BXArXb3KjVAnpD6Hax9gP7kXw+ImbM\nDJykFLaxfweIPf00JJ8P+/6B01Zb6oKEVrA7jxxGZTJhyMoOtN2dXRVAExWNPjNLNmEM4ALXHxq7\nr6kJd0kxxrHjUOn1LQePBGGGMI6fCCoVjmaNf6DoD03VcfAgiGJgN2a26II+VcvUXBTNcWSABbvV\ng9kSuiPxoGU3HjFzVuDIwWDeieFAWLB3QdzpcwGwD2CSjqNZezSZQ/cie2uq8dZUYxw3HkGlajk5\nJ0jtxDw1F/x+HIcGTqgFxiGEgt1xWNa2TZNk4WQyB7/AqU0mDN
k5uAry8XeSldgf2PvBFKPY1825\nzYI9Qq6b4/N2H+pnyM5B0GpxHDkSsv50h9/ffCReCHctkt+Pfd9eNHFx6NNHN5+jGjbFfCcwZWSg\njo7GcfTIgEWFOPohCkJ5CZXEkp4INICIZgFg3zdw5pieHKoQLIq2bWo2q/Rk5wLIJXwlaUC1VbtN\nPrZOG6Iqn5IoYj+wD3VUNPrmzOOWk5S6HweVVotxzFg8pSX4rAMTMRYI+QzhrsVdXITodGKeMhVB\nEFCpBExmXdh5+l1AEARM48bjb2rC26qKXH/SH84yx1G5SJsi2I1mbY+0E31GJmpLJPb9ewes0mGo\nw/wkScJ+8CBqiwV9Wnpz280CLQgnMoBpkrwgOAbQzu6whfbwZldhAX6rFfPU3IBZo+UkpeDGwdg8\nj5xHj4asX13RktcRwnfimNx347gJgc9MEXrstu7zO4YDYcHeDcaxcvqy89jATGJFyChadV+RJAnn\nkcOoLRZ0qXIFQ5VKhdEUvHYiqFSYp0zF39SEp6wsJP3qDiXr1BCirFNPeRn+xgZMkyYHasP0VGM3\nZGahMhqxHzwwIC+/z+vH7fKFdtfSvCgpTnHo+dmnioLgODIwVV0d/RDDrrzPxrHjAp+ZI3T4fWK3\n+R3DgbBg7wbjOFmwO/IGRrAHJnGItp3eqpNytun4CW2KXZkidEFlnyooL4DzeF5I+tUdoc46DZhh\nJrWciNRTk5SgVmOaOAlfTQ3eATiwvT+Sk5S/n6KwyO03C/Ygk3MMGZkIegPOgRLsIfa3SKKIM+8Y\n2oQEtLGxgc+VMOCRkKQUFuzdoEtJQRURMWAae8AUEyKNXdGqTi3cZI7Q4fOKeNzB1cYwjh0LDIxg\n74+sUyVMUQlbBNDpNWi0qh5FQgSiQgbAkRyIkArRIi+JIq4Tx9EmJaGJbCn4pZg4gtXY5XDRcXgq\nK/A19H952lC/E56yMkSHo83iBq1MUiPAzh4W7N0gqFQYx47DV1uLt7am35/nsHnQaENXotXZiWBX\n4sODSVIC0CY3L3An+l+whzrrVBJFXMfz0CYno4mOaXPNZNYFbWMHMI1vti/n9f84hNqR7ikvQ3Q6\nMeaMbfO5orH3xHFomiDbph1H+z865tSjAfuKsvtWduMKph7kdwx1woI9CEwBO3v/F4Gy290hsyVK\nkoTjyBHU0dFok9qW/+2xGUIQMOaMwVdT0+9aWqhj2D3lZYguF8bsMe2umSL0uBxeRDE4k5Q2KUle\n4PKPh6RvXRHqyKCAGWZMW8FuNMsFsHq0c5kwSf7OAJye5rSHVmMP2NfHnaqx93yBG6qEBXsQKBPA\n2c92dlFsLtEaqi1nRTl+axOm8RPbJXa0bL+Df5kVgdDf5phQZ50qZYcNOe0FuzlChySB09GDBS47\nR17gGvv38I1Ql6pV/m6GUwS77EzXYg8iA1dBP3o0KpMJ59H+F+x2m6f5oJG+h3xKkoTz2FFZ2UlI\naHOtJToorLED8MgjjzBv3jyWLVsWiuaGHPr0dFQGQyBEqr9w2r1A6JxErhOyVqnYx1ujJED1RDsZ\nKMEeao3ddUIW7MacnHbXerpzATlJB8DVz3XqQ+08dR0/jspsRpfc/vAWU4QuqNIKCoJKhTFnDN7q\n6n6vgKr4W0KRdeo9eRJ/UxOmcePbtWfqYeLeUCYkgn3lypW8/PLLoWhqSCKo1RjGjMVbWYmvsbHf\nnhNq779SsEoRRK3paagfgD4zE0GjwXm8f80QIR+H/BOoDIZAuGdrejMOxmbN33mifwW70idjCHZw\nvoYGOfs4Z0yHRwGazDo8bj/eILJPFQxZ2QD9WhhNFCWcdk/ozTCnOE4BjCYtKpUwIsoKhESwz549\nm8jIzo/VGgmYAuaY/rOzh7rgkzM/H0GnQz8qrd21nhQCU1BpdegzMuWsvX6s7hdK27LfbsdTUY4h\nK7tjgdbDJCUAQ1YWCEJgR9
RfOOweDEYtanXfX9PO7OsKyjj0RGs3ZCuCvf8WOJfTiySFbpFX3l/j\nuHHtrgmCIIcBjwCNPfSn445QAtvvgnwss+cAYPPa2V65G5WgIjMyndSIFLSq4IZUkiRKqmwcKqzH\n6/Oj16pxVcs1SEKhnYhuN56yUoxjxiKo29smjQETRM8msXHMGFwnjuMqyA9E2tS7GjhQewS7186M\nxFySTAndtNKCy+OjqNJKQYUVm9NLXJSB6pPyGZmhMEEoQqejXUvrZ/RES1MZjOhSR+EqKkTy+RA0\nGlw+F2W2SmpddVg9NqbGTySxB+NQ3eCk+KSVmkYXjXYPybEmbFY3lsj+ta8rKHPObvMQGR3cOb+G\nTEWwFwQ+a/JYOVhzBLVKjU6tY5pxLALB/4ayGjvHShpweXx4vCKG5jyLUNVOchacQGU0ouvkIG1T\nhI6aShuSJA3ps0+7Y9AEe7CHsg42Sj995qmUCgL+smI0ESJrDn/G+vxvcPtbBIJereOHs65mUdYZ\nnbbncvt4d30ea7cXU9voanMtFRiFig0HK7GkRzNtbPCC4dTxbDxYDJJEzKTxnY61KUKH2+Xr0d9C\nNWsa9Z99iqqiGPuUFJ7f/k8K6ksC1z/K/4zxcdlcPuVipiVP6rSftY1O3vz8KGu3FeM/JSJlAgIR\nCHyxr4JLzxpDQkzvDxR3VpYCkDRzKrEd/E7RKz9b8kuBvgUzHo1TJnLys1JM9joKLV6e3foyTW5b\n4PoHJ/7HOTkLuHzyxUQbOt7NSpLEoYI63t9wnG2HKmmdKyYAs1FRWu9k8+EqLpqXiVbTdoHuyd+t\nvCgfQaMhbfZU1Pr2QjIxSW5Lq1YF326ChbKUZNyFBcTGGvmycAtv7H0fu7elCqjmoIZrc5dz4biz\nUQkd7zx8fpEvthbxxbZi8kraOqQjgfGo2Ha8moRJSSyYntprgeuz2zlWWUlU7lQSk6La/5wECzGx\nZqrKrUSY9CEvGT2QDJpgD8XJ5f3NqSes65JTsB4/zs8+e4I6dwMx+miWZi3BrDVT2FTCjpO7+eu2\n1yisKueirPPaTEBJktiTV8O/1h6jtsmN2aDh9MlJ5GbHYTHpcHv9HPy2GFu5lX2FdWx9YTNn5qZw\n9TljMXYT097RSfB1u+WEHCk5vdOxNpq0NDW4evS38CXIduqSndt5Ub0Zl8/FxNhxTImbiElrZGvF\nTo7WHucPm1Zx69TrmBrfItwTEixUVDby4TcFfLatBK9PJCnWxPQxcWSlRBJl1lHX5GbfF3l4PH4+\n2JTPf78uYMVZ2Vxw2mhUvXiha/fLBbs8cakd/k63V3ZY19bYqa62djiWHZI6GoAvv3iP12MLUQkq\nFqbNJ9mUiFqlYm3RRj4/vomvCrfx4xm3k2ZpqyHaXV5Wf3yY3XlybkRWSiRzJiQSH2Ug0qyjsKSB\nE5sK8UgSf//gAO9/mcfV54xj1viEwFgG+3cT3W5s+QUYMjKpa/IA7XcnIvKqUlHeSHxK8AuGdnQW\nrq1b+MM7v2efUIlBrWdZ9gVEaE04vE6+LPuKV/e8y9aivVw/+WoidW3bLqux8/f/HqKo0oogQG5O\nHLPGJWAx6dBpVRzeW0HV4WqqrW6een0H//06hmvPG0dKnDnoPiooNeRVqe3fCWU8NVp58SkuriMu\nIaLHz+hvgl10QybYR0LhnO7QjE7HU1GOVFXDBdPO56LMc1GrZC3qtJRZLEybx1/3ruZ/hWupdzdy\n7YTLEQQBUZR4Y+0xvtxVhlolcPEZGSw9IxO9rq0GdnJfJTbgnqum8eaXJ/hqXwWHCuu5fflkclLb\naxhdoURsKHbQjjBF6KmtsuP1+II+hk9jiUSKjcZZkI97ZgLXTbqKuckzA9fnJs8krz6fv+59mb/v\n/ye35t7A5DjZP9Fk9/D/3t7L4aJ6Yix6li/IYt7UZNStbN+SJLH/02MkJ0bww9mjeG/jCd7d
cIJj\nJQ3cvHQSEUZt0GMgiSKu/BNok5JQR3T8khqMvXOYGZtNO1VH9hC5KI2bp/6A7KjMwPXTk2ezsWwz\n7+V9xPP7/sFDs+8hSi9r7gUVTTy/5gA1jS7GpUez8qxsxqZFtVEEIlUCJyhk/oxR5Khh3c4yVr2/\nn0vmZ3LJgqwe9dVdUgx+f9dzQTHN9cDGDqDPysK6dQuewgJy58zmqvHLida3zNWLpy7iua//wcHa\nI/xt32vcN/P2wDuzbmcp/15/HJ9fZP6UZFYuzCHmlNBOZ7mVqsPVfO+C8aw7Ws3+/FoeW72dW5dN\nYvaExB711VUom4wMWZ2PnzIOTrsHgt8wDzlC4jz9yU9+wtVXX01BQQGLFi3ivffeC0WzQwqv38s2\nXSUAC8VMlmYtCUxQhWRzIg/MvovRllFsqdjOloodeLx+Vr2/ny93lZGWEMGvb5rLZQtz2gl1kF8q\ntVpgfGYsj14/m6XzMqizunj6zT0cKqzrUX9dBfmoLZFoYuM6vcds7rkDtdZZzwmLG6Nb5Oa0S9oI\ndYWxMdncnnsjgiDw0v5XKWgspqzGzk+e28jhonpmjI3ndzefxpnTUtsIdWjJOjVb9MyfmsKvbpzL\n5KxY9p2o5TevbKemIfjDPjyVFXKmZQeJSQqCIGDsRbnWfK0Nl05gVK3Iw3N+3EaoA6hVahann8ml\n2RfS4G7kxX2v4vF72Hm0isf/uZPaRheXzM/koWtmMC49up15QVlooqMMXLV4LI/dMJuEaAMfflPI\nX98/gKsHhapcRYUAGDK6EGi98DUA7NLLO46JNjO3TP1BG6EOEG2I5I7cG5mdNJ2CpiI+yP8ESZJ4\nb+MJ3vjiGEa9mrtXTuWHSye1E+rQ4sxNSbLw4ytyuXP5FNRqgefXHGDtjpJ293dFQLBndqXs9G4c\nhhohEex//OMf+frrrzlw4AAbNmzgsssuC0WzQ4qPC77ggFGO1811xXRq54vUWbhl6nUY1HrezfuQ\nJ975ht15NUzMiOGn184kNb7zLaTdJod1CYKARq1i5Vk53L1iKn5R5Nl39rEnL7iSBr6GBnx1dRiy\ns7u0R5osPZvEkiTx5tH3qIyRp022tXPn5vjYMdw85Qd4RR+vHnybJ/+1g8paB8vmZXLXyqmdmpdO\nTaOPNOu478ppLJ2XSU2ji6fe3E1NY3DCPbBr6SB+vTXmCB32HhREs3nsvHbk31TGa7FYvZjdnX/v\nvIxFnJ48myJrCX/Z9i9e+OAgGo2K+66axvIzszstcnZqyOeohAgevX4OE0ZHs+tYNb//xza8vuBC\nE92FhYBcfrkzeqOx76s+yIeu3fhVkNOk79SGLggC14xfSaIpnnXFm1i1di0fbykiMcbIo9fPZua4\nzlXj1rH8giAwe0IiP/3eTCLNOv61No/3NgYfkeMqKGhWdmI7vUcJKuhJstZQJJx5GgQn7VWsL/kK\nX1IcqFS4iwq6vD/WEMOKMctw+92UGzczd1Ii9105DVMX5zVKUnO87ikOmxnjEvjRFdNQqWDV+/vZ\nd6K22/4G4tezOtdMAMzmniVkbKvcxeG6YwGNx11U1OX9U+InMit+FtWuKlxRedxxWS4rzsru0lau\nCJbWsdsqQWDlWdmsODNLFu7/Ck64K9Ea3Y2DyaxD9Eu4Xd1rwZIk8caRd2n0WIkeI0cFKZpgRwiC\nwDUTVhKvTeaE8xCaqHruv3IaU7I630lBx4WvIoxa7r9qOtPHxLMnr5oXPjiIr5uDuEHW2AW9ocPE\nJIWeFkRz+py8ceRdVFodmrQ0vKWliN7Ov2vQGPjh5O+jktQckjaQkqziZ9fOJD6qa8d4R+WbM5It\n/PwHs0iKMfLxliI+3VrcbX99TU346moxZGV1qewoCoUzrLGPbCRJ4p28D/FLflZOvBR9Wjru4mIk\nX+dCQJQkDuw04a9PQB1Vx4QZTWi6iUV2OeV6JR3F607O
jOX+K6ejUgk8/8EBik927TQLVrD3ZNvZ\n5LHybt6H6NU6zjvjGvk5XQg0gEa7h6Nbk5G8OvTp+czO7d7x4+iiLsiy+Vksbxbu/+/tvTi6EcTu\n4iJQqzuM429NT8Zhb81B9tUcZGx0NhNzF7Y8pwsKym1U7pPNIElTCsgZ1X3OR2dJWhq1ijuWT2ba\n2Hh259Xwj/8d7nKnIbrdchz/6NEdxvG3xmTWBa2xf160AZvXzgWZ5xA1Zjz4/biLuxawhw77cBWN\nQ9B4mTCnhqggok4cNg9GU/vyzfHRRh64egbRETre/vI4Ww5UdtmOMle72rVA730NQ42wYO+GvTUH\nOVx3jImx45iWMAVDZhaSz4e7vPMDJ9798gTbDlUxyn0GBrWeTwq/aBMW2RHdnZw0Lj2aW5ZOwuPx\n8+w7e6lrcnV4H7QW7F072QICLYhJvOb4/3D4nFyacxEJcaloE5PkOO5OhIrXJ7Lq/f1U1/qZrJ+P\niI+Xd/272+d0V6L1kvlZLJmTTkWtg+c/OIC/kxOdJJ8Pd0kx+lFpCJquHcPBVroUJZGP8j9DQDYt\nKDZrxYbdETUNTv7yn/34bdGMi5hMtfsk31bs6PI50PU4aDVqfn7jaeSkRrLl4En+u7nz57uL5bBX\nfWb3DldThB6n3dNtQbQ6Vz1flnxFtD6KxekLMGQpOR6dL/Q7j1bx7/XHMTtyiNPHsa1qBycd1V0+\np7vyzXFRBu6/ajomvYbV/zvMwS78UO4gHKcARlNYsI94vH4v7+V9hFpQc8XYSxAEAUPzC9LZJN5y\nsJJPtxWTEmfivhWncXb6mdi8djaVbu7yWQFbYhfJSbMnJHLl4jE02Dw89+4+3B2kf0uShKuwAG1S\nMmpT1yFhwdZJqXLUsK1yF6nmZM4cdToAhsxMRLsdX017u78kSbzxxVGOlzYyd2Iid5x1PuNixrC7\n4gDHG7rW8oMpJ3Dl2WOYlhPHwYI63lrbcfanp6ICyefDkJnZ5fOgbXJOV+w4uYdK+0lOS5lFkjkR\nTXQ06sjITk1STreP597bh9Xh5drzxnL9tOXoVFo+PPEpTl/nCzO0ONI7K99s1Gu457Jc4iL1vP9V\nAbuPdSwkXUWKwzCzy+eBPA6SJO8eu+Kj/M/wij4uyb4AnVrXUlqgsOPSAkWVVv720SF0WjX3XT6D\n5WMvDCySXeH1+PF5xS7nQlpCBPdenosgwAtrDlDdiXM9GMcpgFqjQm/QhJ2nI5mvirZT56rnrLQz\nSN51YmQAACAASURBVDLLoVXKit/RJC6qtPLqJ0cw6tXcc1kuEUYti9PPxKgxsLZ4Iy5f5xqhI8ia\n00vmpLNoeiolVTZe+7T9IdvemmpEpxNDN1tOCH7b+VnReiQkLsg8J+AgU7a0rg78Det3lbFpbwUZ\nSRZuvGgiKpWKZdnny20Vru/yWcEcqqBSCdx6yWRGJZhZt6uUTXvL292jaNHKgc1dEczOxS/6+Tj/\nc9SCmosyzwVk+7l+dCa+ulr81rbmMUmS+Pt/D1FWbeecWWmcPTONaH0USzIWY/XaWF+8qcs+OZr9\nLV3ZgyPNOu65LBedVsXf/nuI0mpbu3sCAi2I+dCShdv5PC2xlrG9cjdpEanMSZ4BgDYxEUFv6NAU\nY3V4WPX+frw+kdsumUxGsoUZCVPJsKSzu2ofRU2dR7Z0ZZZrzbj0aK49bxx2l49V/9nfTuGRJAlX\nQQGa2Lg2B4x0hnK62HAmLNg7QZREPjwiv8jnjl4Y+FyXOgpBpwts7RRsTi+r3t+Pxydy89JJJMea\nADBpjUFp7cEWvhIEgWvOHUd28zb8y91tTUKK9qgfPbrb36jRqtHp1V1O4hpnHdsqd5FkSmRGYss5\nmYqgcDVHXCjklTbw5to8Ik1a7rlsKnqtHNaZHZXB5MRxHKo7SnFTaafPC/ZlNuo1/OiyXMwGDa9/\nfoyiyraC1V0s90s/
OrPLdiC4sgJbKrZT46pjfuppxBlboioMGfLC4TrFzv7p1uJANNTV57SEWy4e\nfSZmjYmNZZvxdGKeCzjSgygtMTrJwg8vnoTb42fVf/bjPCUM0l1UhMpgQJuY1G1bxiAW+k8L5UV+\n+ZiLAou8oFJhGD0aT0V5mxpCoiTx9Bs7qWkO7Zw+Nl6+XxC4NOdCAD488Wmnz+rJwe4Lp49i4fRU\niqtsvHqKwuOrq8NvberWDKNgMssZ2X7fwBzc3h+EBXsn7Ks5RLn1JHOTZ7aJzRXUavTpo3GXlSF6\n5IknShIvfXQoMIFnnFIKYHH6AowaY7PW3vEWvCfHf2k1Ku5cPoUIo5Y31+ZxpJVtUXHkBaOhKc/r\n6kX+vOhLREnkgszFbcLZFE3Y3cq+3OTw8MIHB5GQuGP5FGIjDW3aWjHxAgA+K/qy0+c57B50ejUa\nbfe1t+Ojjdy8dBI+v8hf1+zH4WoxIbiKikClQp/eteMUujdJ+UU/nxauR6vSckHm4jbXlJ1L63E4\nWlzPuxtPEB2h47ZLJreJ1derdZyZdgZ2r6NTW3tXjvSOmDMhkQtPG83Jeif/+KRFqIkuJ57KCvSj\nM7p1nEL3C1yNs5a91QcYbRnFhJi2NWf0ozNAknCXtSzaH35dwK4jVUzJjm2XVDU+dgzjonM4Up9H\nqbX9jgtaFcULsk7M984dR05qJN8ePMmGVgpPixkmSMHeA9/TUCUs2DtAkiQ+L/oSAaGNtq5gyMgA\nUcRdKk/iT74tYn9+LZOz2k9gAKPGyDnpZ2L3Odhcsb3DZ/a0VG1spIHbL52MKEk8+c8d2Jrtoorm\nqE/vXmMHWai5HF78HYTN1bsa+LZiB4nGeGYlTmtzTW0yoU1KDjhQlcWt3upm5VnZjB8d0669qUkT\nyLCks7f6AJX2kx32x9HDEq3TxsSzdF4G1Q0u/v5fOUJEEkXcJcXoUkeh0nbfVncF0fbWHKTe3cAZ\nKbMD2aMKp2rsjTY3z39wEAGB2y+dQmQHv2Vh2jw0Kg3rSr5ClNqPe7C7ltasaM5e3XGkivW7ypr7\nJDtOgxVoxm58DV+WfI2ExOL0s9qZiJQdorJjPFhQx0ffFJIYY+TWZZM7DHFdPPpMud3Srzt8Xkeh\nr12h1ai4Q1F41uUFdnGKshOMWQ5GRmRMWLB3QF7DCYqaSpgzahrJ5vZpywFttaSIYyUNvL+pgBiL\nnluWTeo0RvvMUWegUWnYVLq5y5fZaAo+ZX5SZizLF2RR0+Dk7/89hF8UcRcVoYmL6zSF/lSUhcTl\naO8w21S2Bb/k57yMRe2ybEHeFYgOB97qaj7eXMjBgjpyc+K48PSOXyBBEDg/82wkJD4v2tDuut8v\n4nL0/ASp5QuymZgRw57jNXy2rQRPZQWSxxP0rkWtVmHo4gShDSXfALAwbX67a5rYOFQREbiLChFF\niRc/PEiT3cPli3IYlx7dYXuROgunJc9s1oAPtrvem8ObNWoVt186BYtJy1vr8sgvbwoqMak1gRju\nDsbB4ZWVkmh9FDMTc9tdNzSbvNwlRdRb3fzto4OoVAIPXzen0zIQk+MmkGiMZ0flbqye9v6B3pz5\nGhtpaN7FSc27OF/LLjZowa4cQhMW7COKdc2OrUsnLunwuiLYrScKeOED+bT62y6ZTKSp8wkYoTMz\nO3E61c5aDte1P4HIYfc0F/rv2Z/k4jMymT4ugX0nalm34SB+a1PQmgl0blf1ij42l2/DrDExO2lG\nh99VIi3yt+9nzdcFxEbquXlp54sbwNT4SSQa49lZtReb197mmtOhnCDVs6p6ijM1yqzjvY0nKN4j\nH9emzwh+HMzmjk8QKrGWcaKxgImx4zpc5AVBwDA6A+//Z++9oyS560PfT3WOk3ty3JyjNiqsJAQS\nCiRjHgbDRRhjHDg8Xb/jc1+wr6/TxX6PCxiuMRgso4vBZIQQKGu1knalzTnvTs6xez
qHqvdHdfX0\nzHRPV3XXzG6P+nMO54jpqq7f/vpX39/3942jo/zqlYtc7pli++oaHtzdsuDz3tVyDwICL/W8Ns8B\nnm+jkUq3lc8+thFRlPjGL87jV8JeVUTEwMLRQW8OHCWaiHJv850ZN3lLQ4Ncvri7i2/98gLTwRgf\nuX8VazKc3BQMgoF7W+4iLiV4vf/IvM+12NjT2bKymkf2yae4J399iXBvD6bKKoxudQW0SqaYZch4\naIIL41foKGtldXXmI6y1sQmMRgbOX2HKH+VDB1Zk1c7SOdCyH4BDfW/O+yzfLjEGg8Cffmwn5S4L\nxw+eBtRrJpDdvnxq5Cz+WIC9jXdgMWbWuJQN5PQbZzAIAn/4/k05i3QZBAN3Ne0lLsbn2ZgLaVpc\n7rTw2ffJpqmzb+QxD65kB6HobOfjweRvdW8GbV1BmYdTh85QU27j04/M7zE7lzpnLZtq1tPl66Fr\nTmRIPqYYhY0dVTx2ZzvjvjCjF6/JjlOPumJZNocFQZgv0BJigoN9b2I1WrizcU/GewWTCUtzC6He\nPq71TLBzjYcHdub2b+yp34ndZONQ/xFi4uy5L2QePnB3B2tbKrh0sYfE1JSqYAIFRw7TXDFQEuxz\nODxwFAmJu5Lx2pkQTCZC5R5c02NsX1HFQ3vULZpWdzMdZW1cGL/CaHCmNEAsliAaSeTdJabCbeVz\n79tIXVj+zkRt5iYCmZhJzpn9Mh/qO4KAwN2N2WvLm5plrbQiMMZv37eKlU3qKlDubbgDs8HE6/1v\nzTJL5auhKaxvq+T9d3VQ4RtBQsDctLDWnI5ycvGnNTKejvo5PnyaWnsNG6rnt1JTiNfKpYwbouP8\n4Qc24bSpM6cdaJI3+jcH3p7190Ln4X13drCp2YUjMEmgok6V4xRkJcHumF8Q7ezYRaYiXvY27MJh\nzl4CIFBei0FMsMYa5vGH16mqm24zWdnfuJvpqJ+Tw2dmfabFkT4Xo8HAH7x/Ix2CbGcPVOSOClIo\n2diXGQkxwZuDR7Gb7OyY4yxM5/zNca7HnJilBJ/YVampTviB5v1ISBzqnwl9DGl0EmVibWslO8rk\nF/IHF0JZMzLnkmkR90730+nrZn31GjyO7DVNfnFsiCmTi6b4FA/snN9PNBtOs4OdtdsYC41zZWIm\nwUirsywTj+xppSE2yZiljGeOZY62yIQjJdhnopYODxwlLsY50Hxn1gJXsbjIDy7K9+yqiNHRoL5F\n5NqqVVTbqjgxfJpQfCaxphBNFWQB/cntZRiQuBiyc6l7UvW9mWK4lY3nrizaOsDIZJBDI7IA/vB6\nKw6VmxvIG5yAwBsDb836e9BfWK/TCpeVRzrkMT3fk8AXVCeoS6aYZcaZsQtMR/3srd+Z1fwwPBHk\nn5++wIhdFniG4eylBTKxvXYzZRa3XNI3IduU83GWZaJ8eoSIxcGZ4Rg/Oaiu6l0mU8yhPtneqWiU\nmXjr4hDPvd2D112DNRpE1Nip/u5m+USUblstVKABJMZGMSViTLk8PHO4S3VFzJR9OdlvVZREDg8c\nxWIws6dhfmlihe+/dJXzkxA3WamYHtE0VoNg4M7G3UTFGMeGTqX+rkcTa9PYIAAjtiq59rvKcscO\np4V4TCSajIcfD01weeIaK8rbaHRlLiIWiSX4p5+fp9con9hck5kjnrJRba9iXdVqbnq7GUxGSyUS\nIuFQrOAuRmU+OSP3pljGN35+XlXRNKvNJNfoLwn25cGb/UnNpCmzZhIMx/jqT84SjMTZcfc2gJyF\nj+ZiMpjY23AHoXiI06Pn5O/N01mWTsLvJz4+TvmqFdRVO3n+aC+vn82tsc7VTkLxMMeHT1Ftq8xq\nfugemubffn0Zm8XI6js2AvMTdHLR5m6hxd3E2bGLTIbldmip7NsCBFqkV/491u/ZjNlk4F9+dYHB\n8UCOu2bmwZ8U7NcmbzIWnmB77Rbspszmh4On+3
nt9ACtdW6cHe3ERoY1N/ne27ALg2DgjYG3U05U\nPZpYK+ty54Ht+EMxvp4hIzMTc9fD4cFjSEhZbeuiJPHtZy7SM+Jn3a4NcvXTXm3vBMD+xt3y8waO\nAoX5W9KJ9HZjcDhZtbGdK71TfO+FqznLM880tS4J9qJnJDjG5clrrKrooN453x4nihL//MsLDE0E\neXB3C3fcK0eKaBVoAPsa5GbYR5LOQz00VeVlcrS384UPyxmZTz13Jecx3GY3z3KYnRw5Q1SMsa9h\nd0bzw4QvzNd/dpZoXOSzj22kZu2qWc9XiyAI3N20FwmJI8nYfj02OGUc9RtW86n3riMUSfA/fniG\nqRyOsJQpxidfd3hQFjCKwJnL2RtjfO/5q7jsZrm+fFurnKDTp635Q7nVzZaajfT7B1NO1KA/e+Er\ntUR65cqW++/blsrI/PavLuYs8JV+gkuICY4MHMNusmUMcQT46cEbnLg6yrrWCn7noY1Y6hsI9/Qg\nqTQFKmyp2YDL7OTtoRPExLgu74QYDhEbGcHa2spnHt1Ia62LQ2cGePlE9sxnBSVxr1g7w5UEexJF\nuNzVON9pKkoS//aby5y/OcHmFdX89r2rMNrtmGvr5BK+Gn/8WkcNqyo6uDp5nbHQuC6mmHBaEkZ9\nlYM/+ZCc/v8/f3aOgbHsGqviMFM0pCMDxxEQ2Nuwc961/lCM//GjM4z7IvzWgRVsW12TSoTKVbo2\nEztrt2IxmHlr8ASiJBIMROXa2xra380lPUFr38Z6Pnh3B+O+MF/58Zl56fbppNvYg7Egp0fPU+fw\nsHJOZySQW9v90y/OYzQKfOHDW/BU2LG2JHMbNJ7gYMZ2/cbAW8RjCaKReEFrQUomz1kbGzGYzXz8\n3WtY21LBiSujPPX8lQXXa7rGfmH8Mt6oj11127EY54/ntdP9/ObtHuqqHPzRBzdjMhqwtrUhRcLE\nRrSZY0wGE3sadhKIBTk7ekGnTb5PTtBqacVqkes3lTkt/ODlaxy9tPD4lBr9UQ2dqm4nSoId2Z76\n9uAJ7CYbWz2bZn0mSRLfe+Eqb5wbpL3ezR+8b2OqNrS1pQUxGCA+kbv5xVz2N8ia4JHB4/os4jkZ\np2tbK/nUe9cRjMT5hx+con8B4a5oJ0OBYTp93ayrWk2lbXb4Zjga58s/OsPAWID37Grh4WQSkqmq\nCoPTSaRXm6YKcvOFHbVbGQ9PcH3qplx72zm/9rYWIr09mKpmErQe3d/OPVsb6Rn28z9/fo5INLM5\nIt0Uc3T4FHExzr6GXfMiO/pH/Xz1x2eIxUU+976NqUggm5J5mYcZQnaiVnJy5CyTPjlRpxDBHh0a\nQopGU2vBZDTw+d/aQmudrLH+7FDmKozpzw0GoryZNItkMsO8fnaAp567gtNm4n//7S2pMFdbARuc\n8k4cHjiqi8Ye7p1dN6m63MYXPrwFq9nIvzxzMWtFTCj+FnklwQ5cmriKN+pjZ922WU5TUZT4wUvX\nOHiqn5Zal1z7Oa0LUioDNQ9tdXvtZmxGK28NHk/VAS/UFCPHLM/UqblzcwMff/cafIEo//D9k/SN\nzM/uA7C7LMSiCd7slU1DiqlIwReM8qUfnqZz0Medm+r5yP2rUgJPEASsLa3ERoZJhNT3I1XY23AH\nMLPBFTIHce8UCa93VsyyIAh84sE1bFtVw8WuSf6/H55KlV9Ix2I1YTAK+H0RDg8cxSAY2DPn1HKj\n38sX//0kvmCM333PWrantXSzNDSC0ZiXYDcIBvbU7ySaiHKm/zKgjzkqPVHNYTPxnz+yLdV16Iev\nXEPMoLkr8z/pnebixBVa3U00u2eHzx483c+Tv76Mw2bi//joduoqHanPlLkP5zEP9c5aVpZ3cHny\nGmNTXnk8eig7afPQ0VDGEx/Zislo4J9+cT6rcz1XeYXbHV0E+6FDh3jooYd48MEH+da3vqXHVy4p\niq17X1LIAA
TCMf76X9/mpRN9NNY4+dOPbpuXfKMkwITz0E4sRgs767YxFfEy4Z1esPZ2LhKRCNHB\nQawt87vkvGtnM598cC3TwRh///2TnL0xfyErNeBP9VzAaXKwxbMx9dnAWIC/+e5xbvT72Luxjk89\nvG5eeKcyD1GN9mWAVRUdeOzVnB68KNfeLmhzk58/t06O0WDgjz64ib0b67jR7+Pv//0k497ZxdgE\nQcDhtOD1Buj3D7K5ej1llplMxbM3xvh//+MUoUiC33tkPfdtnx3eKZhMWBubiPT1IiXU9SJNZ09y\n7V3ol7OSC5qHLPWCypwW/vSj22iodvD80V65xO2cE4wiSPvGhxElkb1pm3xCFPnF6zd56rkruB1m\n/uxjO2irn53NaU3mNuSzwQHsa5Sf1z0qO/4Lm4ceBLMZS33DrL+vbq7gCx/egtEg8LWfnuVXh7vm\n+R6cRR7yWLBgF0WRv/7rv+Y73/kOv/rVr3j22We5cUN9g9lbTSAW5NzoBeqddbS55UV5Y8DLX/3b\nMY5fGmZjRxX/5eM7MpYLSBU+ykNjB9ifXMTT06FUE+t8CPb0yl1yWjIn5Ny7vYnfe2Q9kViCr/z4\nLN9/8eqsRsj25CKOhBLsqt+O2WAiIYocPNXP3/6vmbKrv//ohlmVChUUAZKPI1kQBFlrj8jfq4um\nmqEAmslo4DOPbuCBO5rpHwvwF//6NgdP9c/SWh1Oi6yhSTMCJhiO893nLvOVH59FkuBPPrSZOzc3\nzPt+5blSLEZ0WJt9GaDGXsWaipWMTckRQos1DzXldv6vT+xkXWsFp66N8bf/6wRXe6dSnyuCdGzK\ni0kwckfdtuT/D/H3/36KX77ZRXWZjT/72A5aaufXIzK6XJiqqvMW7Ns9m7EYLYwm5yHfkE8pHic6\n0I+lqRnBOD/BaV1bJX/2sR1UuK387NBN/vt3j+JNc7Ar85CpzEQxkJ+KmMbZs2dpa2ujqUnWYB55\n5BFefvllVuboDH+7cGz4FHEpwd76ndwY8PGrw12phtEfeWAN79nRlNXmayqvwFhenpd9GeSQv3pH\nHVLEgLUy/58ikOzmtFBFxzs3N9BS6+Kbv7zASyf6OHVtjPt2NHHXlobUIjbFrGyr3s6JKyM8/UYn\nfaMBrBYjv//oBvZtyt4I2VqAfRnktPJXzsjOaz00VVuWeTAIAr/zrtU0e1z88JXrPPX8FY5cGOL+\nHc1sXVWN3WkGUaDcUEGdqY3nj/bw/NEepvxRmj1OHn94/YIJSNbWVjgsR6RYG9Vn/yrsbbiD586f\nBPKfB0mSiPT2YK7xYHQ4Ml7jtMlNsb//4lUOnh7gi/9+kl3rannPrhbaG9wYTQKJMGz2bGRiUuSn\np65w5PwQkViC3etr+eSDaxdMQLK2thI4fYq4dwo86uqzKNhMVnZ4tjB8TijIkR4dHJA7aC1QVmJF\nYxn/9VO7+Oenz/PW+SFOXh7hwLYmHtzdoqo2/e1MwYJ9eHiYhoYZDaauro5z584V+rVLxuHXbuK2\n1fLTX0SIhk4AsKa5nA/cvYK772hldHThxtHWllaC58+R8PtVV1RUEASBXVU76ZQgYgzm/W8I3OyS\nx5KjNkprnZu/+NQufn7oJgdP9/OTgzf4+aGbNDsEagFLqIIvfvs6kgQCcNeWBj50zwoqciSJWOrl\nAlD5OMwAKm0VtFrlscfN+dfnCPf2YLDbMdXUZL1GEATu2drI5hXVfO+FK5y6Nsa1Pi9mk4GV9ghu\nrIhDTfyXf5ZzGkxGgQ/e3cF797blbEg+43PpgT3ZSzFkY1vtZl6NyzZ2m4Yqn+nEp6ZITE9jX71m\nwetMRgOffGgd+zc38IOXrnHs8gjHLo9gtRhZb4hgilk5fzzB4SHZgVpdZuV337OG/Zvqc54srS2y\nYI/09sIq9WUdFPY27OTZ2GWwJPJ2pCvm0VzlqxXz1MkbE/zwxSu8eLyXF4/3
Umkzsgq41jfIPopD\nSU2nYMGeb5ynR+NOvliU9zVisVThrK6mbUMZ79ndxuZVM4Ih1zgDa1cRPH8O2/QYFR2Zj+gLsT+4\nk05OMMl43nMyeLMTwWikactaDJbcmt7nP7qDx9+/mVeO9/DayT68kWvgr0OYqmJ9exWbV9Zw59ZG\nOhrV1X4BGGxvI9DVTXWFDYM5u1DK9m9cX76Wq/gYZgCPJ3Ps+EIkwmGuDg9TtnEDtbW50/o9Hjd/\n9bkauod8vHlmgDfODBCMD+CmkcRoLVtX13Dn1ib2b26gXGX2Y9yxnj5AGh7I+7f0mGqJAn7HOOs8\nC6+nTM+Y6L4KQNW61arG4PG42bOliWMXhzhxeYSzN4eJ+idwBMqxhl3csb6Sh/a2cceGeowqhaxh\n01omngHT+FDWcS5Edc0WXox3EbIFcFeYsZltuW+aw/SY/Oy6LesoU/H8h+vKeffuVl4+1suxi8Pc\nnOwkOi4QIXbbyCotFCzY6+vrGRiYyXAcHh6mtjZ3NblcmvBS4XY5KBMcfOKTM45TZWwejzvnOMUa\n+eUbOXeZWEO75uf7RuQ42QlxnDOd17KmbWdDEkUC3d2Y6xsY90YA9RrvvnW17F3r4YsHj8BwHfva\nV/LgYzPt77T8RoaGJqTrNxg4dzWrlrTQfNojZYCPs5PnGRrOXP99IUI3roMkYahv1DRuh1Hg3Tua\n2L3RzZd+fhTGG/n9d+1gzUY5SS0aijIaUn8cN9d4mL5xk5ERX14+E1vcSVgI8XLnm7Q5s5/Ass3l\n+DlZ449X1WmahxV1Lvl/66Z58RdhhEAl//UTu1MmoYnxzBFVmYiVy9FCE5ev0Yz2dz0WjSMkjMRM\nYV64eDjl79DC1JVrIAiEnFVEVDzf43EzNRlk56pqdq6q5t8vXeTwwDH+eNunbxtZBeo3yYKdp5s3\nb6anp4f+/n6i0SjPPvss73rXuwr92iXD6ZKTc/I9eaQch3nalxUbXswS4a2hzK3SFiI2MoIYDmsq\nS5pOr7+fgbhc7yYRzj/LrpAIIYBIQN7gvMIklyauar9/AYehGo4NnyJqliNlCnGYWVtaSUxPk/BO\n5b44A2JIQLLEOTt2nmBMe/io1m5BczkyeCxlDss3httUXYPBbi/4nYibI6mINS2k/Ax1dRhs2rX9\nSCLKyZEzVNrKWVe1OvcNtyEFC3aj0cif//mf8+lPf5pHH32URx55pGgcpyB73RMFZJjJHdqteduX\nlZfHZIWjQydJiNpC5RSBls1hmIu3Bk8gGuIYjIU5itK7SuVD+sv81tAJzfcXItglSeLI4HEkc2zW\nWPIhFcedx3qQJEmO5XdZiIlxToyc1vwdkd4ejC43psrsDS6yMRme4vLENdxuuTZOvvOQym0YHiYR\nztzjdyGUd6LM7eCGt5ORYPZEokzEx8cQQ6G834nTI+cIJyLsadiZtarn7Y4uo77nnnt4/vnneeGF\nF/jsZz+rx1cuGWo61C+EYDBgbW6RO7THtH+H8vKsbmhjOurn4sQVTfcXItBiYpzjQ6dwW1w4XbaC\nsuysTc0gCPlvcIEoJrOB2rIazo1eIBDT5kyO9PSA0Sg3QdFIl6+HocAwq+rlzamgeSigxEIkHEcU\nJWoqKhAQNGuriWSbQmtLa15moLcGTyAhsaJWbpBR8AYnSQS7ta8H5bntHvm31DoPhZ7elAYwe+vv\nyHHl7Utxbkc6okdYk7W1FUSRaL/6+t8KyrH/jhbZtq2kcatFrfc/E+fGLhKIB9lVvx2nq7CiRwab\nDXNdHZFe7bVzYKb29r6GO4hLCY4Pq9dWpUSCSF8v1sYmBJN2t5FSUXBfm1yeVw+NPZ/QT+W55W4H\nG6rX0u3rZcA/pPr+mYxT7WtBlETeGjyGxWBmbf0KoHCTFID/Zqfme5WNdVVdG3aTnbcHj2s6yabe\niTzMUWOhCa5O3ZAT5xboRXC7844X7Hp0
S0lpaXmYIZTnrqxrodXdxIXxy0xFvKrvj/T2YPXUaA61\nhJkyxXc27sbutCBJEM6Qbq8WW2sbYihEbEzb0VkUJULBKA6XlV11OzAIBt5MK2Obi+jQEFIslteL\nHI6HOT5yhipbJRtqV2O1mQpaC6bKKowud14ae3oxOKWsw9z2gQtRiGC/PtWZKlNcWe6aNZ58UHwu\ngc4uzfcq819WZmdX3Ta80WlNJ9lCNPa3FW29QbvD9naiJNiz9PzUQiGOQ7n9lwmTycj+xt2pgmRq\niHu9JLxTODsy92ZdiLHQOJcnr7GyvJ16Z50uRY9mzBDa5iEciiFJ8m9RbnWzuWYD/f5BeqZzl1eV\nn9clP19D82qFkyNniSai7G24A4NgwOW2FiTYBUHA2tpKbHSURDB3Hfh00ovBba5Zj9Ps4O2hE6q1\n1ZlSAtrnQaluuq9hly6nWKV2TiAfjT2tAJgSEXNk4Jjq+yM9PRjLyzGVqw/XBbmD2uHBY1iNw/az\nyAAAIABJREFUFrZ7Nue+4TamJNh1qAlhaWzKu8FAeu3tO+q2YTaYOTxwdFYv0GwoJwRnR7vm5x5O\nvihK5T5dTi55OlDnNthQxvRG/1tZ70lH2VBteQi0wwPHEBBSdYJcZTbCwRgJFZ12sjErUUkDMxq7\nFZPBxO76HfhjAc6MXVB1f7inB8FiwVKvLWQ2FA9xauQcHns1qyo6sCeTowra4JK1c4Ld3Zpr56QL\n9hZXE02uBs6NX8IXzR12mPD7iU+M56WtX5y4wlTEy676HdhMhXVuutWUBLsOGrvBYsFS30Ckt1dT\ng4FU+6/kGOReq1sYC09wbTJ7aVWFcHdSsK9coWm8CTHBW4PHsJvsbE82ULiVGvvcssXrq1ZTZavk\n+PBpQvHcURWRnm4QhKy1crIxmFamuMomR5G43PILHQ7mb5KaqSFU2DwovQFe7zuS9R4FMRYjOjiA\ntblZdfNqhbcGTxATY+xv2I0gCBiNBmx287ym1lqxtrUhRqNEhwY13Rf0y450s8WIIAjsb1B/kk1F\nieVhlns9qUjcnaEnQ7FREuw6VXGztrbKDQZG1duXQ0nh4XDOZGoq2qrSwWchlKO3a4U2wX5+/BLe\n6DS767enyhTrobGbysowVlRoPrnMbTSS3gv0+PCphW6diVmu1R6zrPgY0rskKYK9kHlImeY0nlzm\ntoOrd9aypmIlV6duMBRYuLBYdKAfEgnNZhhJkni9/wgmwTgrEShTU2utKPMQ6dY+D+lF8XbXb8ds\nMPN6/5GcJ9l87eujgXEujl+ho6x1XpniYuQdL9hNJiMWa2EOM8jPgTpjgpg59q0ob6PeUcupkXN4\nIwsfPSM93Rhdbiw12rz3mRooOJNp84U2FrC1thGfnCQ+rb65daZGI/uUXqD9CztR42NjiMFgqtGF\nWsLxMEcGj1NucbOlZkPq704dBLu5tg7BastbY7enbfR3N8s1Z17PYZbK13F6ZfI6w8FRttduxW2Z\nccA7nBaikQRxFX1Ss2FtawcgnPSBqCEVy59WBM1hdrC7fjvj4UkujF9e8P5wlpLFuXj55htyb9em\n4tfWoSTYAXRpXGtTFnFXl+p7UppqmkATBIEDzXeSkBK80Z/9CJ4IBOSY5bY2TTHLI8HRlGbS5Jqp\nRTLjMCvw+J2HGSJTa8ByaxmbazbQ5x9I9QLNhCI0tEbEvDV0gnAizN1N+zAZZkIkXW7brDHlg2Aw\nYG1J5jZE1X9PwB9JOdIVttZspMzi5u2hE0QS2b8rX8fpoeQaO9A8u2iZLj6X5hbZ96RBY1cc6XPL\n9R5ovhOAg71vLnh/pLtbDr1VUdZEISEmePnmYewmOzuz9HYtNkqCHXkRh0M6Ocw0LOJsLfH2NOzE\nbrJzqP8IsURmW2+mLjlqeLVX1kzua7l71t9TDrMCN7h8en9mm4d7mmRh82rv61nvjeQRsyxKIq/1\nvYlJ
MHLXHA3NVVa4xg7JVnnJ3qNqCQWiqYQ5BaNBjpYKxcOcWCC2P9zTI/sZmptVP28yPMXZ0Qu0\nuBppL5ut4ephojRYrdibGjU1t86k7AA0uRpYVSF3VxoKjGS8VwyHiQ4NYm1t0+RnOD16Dm/Yx976\nnRl7uxYjJcGOPkX1jQ4H5to6wt1dquOvszWxthot3NW4B38skDVRJ9zdBYBNQ4ifPxbgyOBxqmyV\nbJvT29VoNGBzmAno4GsArSappAliTqnatZWraHY1cnLkLGOhiYz3ztRGUX/0vjRxjZHgGDvrts0y\nP0CaYC/UcagxQkh2pMczNpa4q3EPAgIH+97MuLYkUSTS24uloUFVdU+FN/rfQkLinub98059ejWa\ncK1ckWxunVkYz2WhXqeK1n4oy0k20tsjN5xJnp7VIEkSL/a8hoDAPc3aSy3frpQEO/o5UG1tbXJz\n67HMfRTnEsiiqQIcaN6PQTDwat8bGV/mGYHWrnp8b/S/TUyMcV/znRmrJzqdloJfZHONB4PDkYrY\nUUMwEMXuNGOYo2UJgsC7Wu9BQuKVLFp7uKcHU2UVJnfuUr0KB/veAODepKBIJ2WKKXiD09YPN7TA\nWqi0VbCzbiv9/kHOj1+a93lsZBgpEtZkhgnHw7ze/xYOkz3VJSkdvRpNOJOOfbV29oUE+9aajZRb\nynh78HjGaKl8lJ0rk9fpne5nT/N2ah2e3DcUCSXBjj72REhzFiUXWC5CSU3VmaHed6Wtgu2ezfT7\nB7k2Nb/VYKS7G4PdPqt59ULExDiv9b2JzWhjX2PmeucOl+wwixXgMBMEAVtbO7HhIRJBdfVeFmpi\nvbN2K5XWCo4MHMUfm53wIzevntKkrQ/4h7g4foUV5e20ls03WzidFgRBB5NUY5Pc3FqlSWohgQbw\nnrb7AHi+65V5G324S04CsmlIVDvUf4RAPMj9LXdnND/oEQYMssYO6k2UC82D0WDkQPN+wolIRlv7\njGBvVz2+F7pfBeD969+j+p5ioCTY0U+w2zQK9kAggsEgYLVlrm9yX8tdAPxmzssshsNEh4dkW6JK\nx+nx4dP4otNy+QBT5rBAvY7fyganRluNRePEogkcWZpZGA1G7mu5i6gY4/W+2ZEh+djXn+18AYD3\ntN2b8XPBIMz0Pi0AwWTC2tSsurl1LsHe5Gpgc80GOn098zZ6xWFva1Mn2COJKC/3HMJmtKXMG3PR\n6xSrJM+p3uCy2NgVDjTvx2ly8HLvoXlljSPd3QhWG+Y6dQla3b5erkxeZ23lKlZW5Vfm+HalJNjR\nJzkH0h2oXaquV7JOswnnjvI2NlSt5erkdS5PXEv9PdIrN69Wm4QRS8T4deeLmAQj97ZkfpFhZh4K\nFWqK5hjuzJ1OHgwosfzZbcPKZnSw7w3CaUfwlIamUmPv8fVxevQ87WWtbKpen/U6R4EF0RSsrW1y\nc+vB3MXhsvlb0nmw7X4Anu96ddbfw12dsuNU5Ty80f8W/liA+1ruxGG2Z7xGL43d5HTKvqcedb6n\nXPNgM9l4oPUAoXiIV5MmNQAxEiE6OICttVW14/TF7oPAzGloOVES7OinsRudTswejyoHaqZ43Uy8\nb+V7AXj6xq9TyRlhjbVRXus/zER4kgPNd6YyLDOhxNMXHPrZnhTs3SoE+5xyAhm/z2Tj/pa78ccC\nPN89I9RmTBDqErSeufk8AI+teHDBk47DaSURF4lG8jdJyePqmDXOhcgWGZROR3kraytXcXnyGlfH\n5MxkKZEg0tONpbEJgzV3Gnw0EePFnoNYjZZ5kVHpWG0mDEZBl2bO1tY2xECA+MR4zmuV9ZDJiaxw\nT/N+XGYnr/a+ntLatTpOu329nB49T6u7ibWVq1TdU0yUBDv6aewgmyHEQID4+MIO1Eg4jpiQcgr2\nFncjd9Rto9c/wMmRs/K93eodp/5YgOe6XsZhsvNQ+/0LXjtz/C4sIs
RUVS1XOFQR069GoAE80HqA\nSmsFr/S+zlhoAkmSCHfexFhRgakid1OJ61OdXJy4wpqKlTm74ug1D6kNrjN3eYhcphiFhzveDcC/\nnvwhoiQSHRpEikZTz8rFq72vMx31c6D5TpxmR9brBEE2Sekh2BVnphqHeiAQxeYwY1ygcbjNZE1q\n7WFe6T2U/O6u5LPacz5DlER+dPVpJCQ+uOqRvGrX3+6UBDtgs5sRhMJty6Dezp7LlpjOYysexCgY\neebm88TFOOGebtXFnp7rfJlQPMx729+FY4EXGfQ7uQiCgLW9ndjYKAn/wr0y1ZggACxGCx9Y9TBx\nMc7Prz9LfHKShNerSlsXJZFfXP81AI+tfDDn9XqZIaxNzQgmkzqTlMr1sKqig931O7g52cOhviMz\np5b29pzPGA6O8uuul3CbXTzQeiDn9Ypg18MkBRBRc3LxR3HmWAsga+1ui4sXe15jKDCcMn+q0djf\nGjxOl6+HnbVbWbMMtXUoCXZgRjsp1LYMaY7DHNrJjKaa+/hcY6/m7qa9jIXGefbys0T7+7C1tee0\nJfb7BznUf4QaWxV3N+/P+Rw9Ty6KoMm5wanUVEGOkFlR3s7p0XN0npdjme0qBPvzXa/S6etmR+0W\nVpS357xeL1+DYDJhbWsn0tebMwM1FIgiCLKSkYsPrXoUp8XBMzefw3dD7g9rzeE4FSWRf7/0E+Ji\nnI+s/cCC2rqCw2VBTEhEwvm1jVRIKTs5NrhYNJF0pOdeC1ajhY+u+SBxMc5TF39EuKsLwWrNqewE\nY0GevvEbLEYLH1z1iOp/Q7FREuxJHAU2tVZIFYDKqbHnti2n89iKB6m113DhzKuy4zRH4a9ALMi3\nzn6XhJTgw2veh9mQu7OQXho7gK09Gb+cQ0tTa4oBeQP+8OrHEBA4d+ol+Tk5BHunt5tfd71IhbWc\nj679kJqhF9wuMR1be4ecgZqjMJrib1FjFnBbXHx8ywcJJyIMXz0jtwTMUdnyzYG3ueHtZKtnk+pa\n44rSESgwWcvocmGuqyfcdXPBDFTF9KVG2QHYVruZXXU76J/sITI4ILcEXEDZkSSJH1/7Jf5YgIfb\nH6DSVqHtH1JEFCTYn3vuOR599FHWr1/PhQvqakbfrjicFuJxkVi0MIeZ0eXCVFOT04G6UHJSJmwm\nG7+36XdpnJDHF2/OrpmIksi/XfgBY+EJHmq7n81pRa4WwmwxYjIb9NXYcwl2laYYhbayFt634iEq\nRmQTj9CcvRJfOB7m3y78AEmS+E8bPqpKS4UZwVKojR3SI4Sy29klSSLojy7oMJzL/Sv2s8rVimPE\nR6imbMGWgDemuvj59Wexm2z8b2s+oNqmrOcGZ1+xEjEUWrCEb0CDeVLhI2veR7vfgiBJRBqqFrz2\nmZvPc3ToJK3uplQo8XKlIMG+Zs0avv71r7NrV3G3kQL9Mu1A1lZFv3/BEr4zyUnqF3Gzu5GdIbmS\n43+E3s7YvT0hJvjZtV9xceIKG6rX8sgK9YkXejrMTBWVGMsrcjpQlSbWFqv6XqUPtNxDw6TERJmR\n73U9k7HD0Hhokq+c+iZj4Qne3XYvaypXqv5+XU8uyRPFQoI9GkkQj4ua1oJBMPCJqvswiXDdHeLZ\nzhczXnd54hpfP/0vxMQ4v7vutym3qs/Q1cskBaROmOGb2edB2UDU2NgVHGYHDxnWAfBi4irnxi5m\nvO5g75s83/0KHns1f7T192YVfluOFCTYV6xYQXt7e8Hmi9sBPe3L9lWyQyZ841rWawIabMsKkiTh\nGJwk6rJxnXH++7Gv8ubA28QSMSRJotPbzd8f/0de7XsDj72axzf8DgZB20+smKREsfDf1NbeTnxy\ngrh3Kus1akI+5xIfGcYUjROsr+T06Dn+7uiXOTt6QY6UiYc5N3aRvz/+VXqn+9nXsItHO7RlFerl\nPAW5hK/B4Vjw5JIyy6k0QSiYBu
SNPVBXwW+6XuKpiz+k0ys3E58IT/JKzyG+ceZfEZH47OZPsq1W\nW7s3p1OfujkAthXyxhq+OT+LWkFLQEE65UNyiejBWgvfPPtdXuh+lfHQJJIk0Tc9wJMXvs9Prv0S\nt8XFn2z7zLz6QMuR5b1taSC1iHXQ0uwrZcEeun6dsn2ZE4JSha80CLX4xAQJr5eqHTt5fOPd/MeV\nn/H9yz/l+5d/ikEwpOLc72zczftXPpwzCiYTDqc11dRaq8Cdi629g8CZ04S7unBtnV+PRBQlQoEo\ndU3qtUiYMe9s2v4uhhoDHB44xjfPfReb0Uo4IQsho2Dkd9Z+iDsb92gOZzOaDHJTax0EuyAI2No7\nCF68QMLvz9h0PJDH6Q1InYYevPPjdE78hreHTvD20AncZhfTMdlUZTGY+YMtn8oZ4pkJXcOAm5oR\nzGbCndkFeyCPDU6SJELXr2EsL+f37v5j/vncv/H0jd/w9I3f4DQ5CMTlshaNznr+04aPUmPX1rug\nWMkp2B9//HHGMhS1euKJJ7j//oXjohfC43Hnfe9iUN8oCxdBmj22fMYpVmygz2Ih1n0z6/3RcByH\n00J9vfqGu2NXzwFQvXkDWzfdza6Ojfzowq+YCE4RSUSxGM18eOPDrPdof4kVqmuc3LwyitVsKvg3\nMm3byPjTP8cw1IvnATkZJv07/dMRJAkqq5yanjU9JJfCbb1jO19Ys5rf8j3ED889Q59vkFpnNR5H\nNfd27GNVdXte4/Z43JRV2Jn2hnVZp8GN6whevIB1apjKjoZ5nw/2eAGoayjT9LxY900MFgsb9uzm\nHw17OTt8iYOdR7gwcpXtDRvZ0bCZXU1bqXLk5yS0W+UInXhMLGgelHuHV6/Cd/kKVS4TRvv8jFcx\nLp8SW1orqax2qvruyOgoiakpqvbuYf2qjaxs/L95vfsoNya6uTnZTXtVM4+tfTfbGzbm3OBvN5lU\nCDkF+5NPPrkoDx4dzd2YdimJJ731I8PTqbF5PO68x2ltayd4/RpDPSMZF7HPG8JVZtP0/aOnzgOQ\nqGtO3mfmtzs+OG+chcytYJQXf3/fJEZLYUFTiepGEATGz5zH8eD0vHGODcv/bTQZNI158uIVMBoJ\nuqoJj05jxcUn1/zO7IvE/OZBGaPVZmJ0KMbgwBQm8/xKmFoQa5sAGD59gXjzfFv/0IAs2EVJUj3m\nSrtAsKcX+9p1jE/K2ZdNplY+vroV0vb1RABGA/mtB1GUEASYnAjkvabSf3NjcxtcvETf8XM41s0v\n6TAxLhd5C0diqp83ffQMAIaW9uQ9RvbX7GN/zewSvGNjC+dTFPKuLyVqNx/dwh2L3c7u1Cm0S8G2\nchUksyPnEo8liEYSmk0doZs3wGDQVL1OK3ral40OB9bmFsKdNxFj8xuGaAl1VJDicSK9PVhbWjGY\nc8d858tSOlDzsS37Ll8BScK+Kv/TWS4MSkG06cLnANLs7FnmIdVBSsNGGrpxHZgxf5aQKUiwv/TS\nSxw4cIAzZ87wuc99js985jN6jWvJ0VOgAakXLpxceOnkLdB6urE2NauqCZIvelX1U7CvXo0Ui2Us\njKY11BFk+7oUj2PX2MBbK3rOg6miAlNVNaEb1zPGcSvKRKbyzdnwXZCjP+yr1xQ8voXQqyAazETG\nhLI4UIP++R2kchG6cT2ZCLa8qjMWSkHO0wceeIAHHnhAr7HcUowmAza7WZfQLgDbSlk7CV2fHxmT\nj0CL9PUixWIprWex0H2DW72WqVdeJnTtKuzbMeuzlNPQrX4eglfkZsb2Net0GV829HQcAtjXrmX6\nyGGiA/1yL9A0gn456zS9iXUufJcugyBgX7nI68FlZXTITzQSx2or7IRkqqzCWFFB+OYNJEmaZfNO\nxEUi4Tg1deojVsRIhEhPN7aOFRjMy6OlnV6UMk/TcLosuoR2AZjcZZjr6uRFPEdLyycRQ9FycmWc
\nFopTx9hlkDV2QBbsc8hHUw1dvSJ/75q1OowuO8qY9BLsjrXyRhRMjj+dgD+C3WGZ10EqG2Isiv/a\ndaytbRhsmcvu6oWe60EQBOwdK0l4vfMqPeZzig13dYIolswwGSgJ9jQcbqvcQShaWG0MBfuKVXK2\n3eDsbDslo1GTQFM01UW0qQLYHMkOQjpkXYKcqGT2eAhdn2+GCE5re5mleJzQtatYGpswlWkLkdSK\ncnIJ6DQP9qRgV35HBSXrVJNA60yao1Yv7lqAxTjBJTf6K7M3uFSoo1P9O6GYOW2LfGopRkqCPQ29\ntVVbMlEpNCdRSUvhK5CbFQcvX8JUVY25tk6XsWXDYBCwOy26vcgA9lVrEIMBgr19s/4e8EcwGAVV\nha9Arr8jRaPY1y6utg76m2LMNR5MlVWErlyZZa/OJ+s0nDTvLbZ9HcDp1i9JCcCxXi5vEbw0O0M0\nmEcsf8lxmp2SYE9D7+O3suDC1+YIdo2mmEhPD2IggGPDhiWpHe10yZUu9Yp0UgSQ7+Lslzngj+J0\nWVX/mxRtVzFrLCZ6a6qCIGBfu5aEf5rowExHpXyyToNXZbOWfdXiC/bUyUWnebA0NWN0uwlcujBr\nfWk1xUiiSOjGdUzV1arq8b/TKAn2NGZqY+ijnVgam+RFfPH87EWs0XmqaDeKtrPYOF3WlDNLD5Tj\nt+/ijBlCFCWC/ogmDW2pHKdAMuzOoFt0EMxsSKErl1J/05p1Koki4RvXsDU2YCpXn9yWL3qfXASD\nAce69SSmpoilFQTT+k5EeroR/f4leyeKjZJgTyNlitEpblcwGHBs3ETC6yXa15v6e9AvF74yW9TF\n6wYvyZUzHeuWSLAnj9+BaX02OHN9A0aXG9/FGYEWDkaRJPWaqhSPE7p+DUtD46Lb1xWcLqu+Jqk1\n8x2oWjX2aH8fYihE2frsPVv1xKljpUsFx/qNAATSzDFaywkEzstZ2M5N2urfvFMoCfY0UuVaddLY\nYWbhKQsRwO+PqDZBiLGoLNCampdEQwP9fQ2CIGBfs4bo2BjRoaFZ36021DHc3YUUiaSckEuBw2kh\nFNSnIBqAubYWU2XlLDu7Vo1dOb2VbVwawa6EYOql7EBmO7tWG3vg/DkQhNQmUWI2JcGeht4CDcCx\ncRMIQkqwJ+Ii4WAspRXnInzjBlI0uqRHTr01dgDnlq0A+M+ckr9bY6ijEua4FPZ1BYfLgiRBKKjn\nBreOxLQvFSml1d/iP30KBIHKnTtyX6wDBoMBu9Osq0nK7PHIkVKXLyEl5JLLWk6xiUCA8I3r2Fas\nxOhUV1PmnUZJsKeRqsmuo8ZucpdhbWsndP0aYjg0I9BUaqpLbV8H/TrnpOPcvFXe4M6clr97WqOm\nelk24yx2/Ho6etuXIS2e/bL8u2rZ4BJ+P6Hr17CtWImlYum6/zhdVgL+iK5lQxzrNyCGQqkG14FA\nRHUHqeClCyBJODdv0W08y42SYE/DaJS1Ez01dkiaYxIJgpcupR291WmqwUsXwGDAsQQhfgrKpqPn\nPJjKy3GvWU3o+jUSfr8mm2oiECB4+RLW1rYlM0dBWv0gHU8ujqRpzn/yhPzdGrJOA+fPgihmLIG8\nmDhcFuKxwruLzfrOpAkleOkCoigSCsRK9nUdKQn2OSyGdpJuZ1eEhBpTTCIYINzZKadML3KGYTqu\nRTDFAFTt3gWiSODc2Rmbqop58J8+CYkE7juWtlNXyiSl48nFXFWFbeUqQlcuE/f5CGrIOvWflk87\nzq3bdRuPGmYK5OnoSF6XPLlcukgoEEs+J/fpTZIkAufPYXS5sbaW6sNkoyTY5+BcBO3E1rECg8NB\n4MI5/Elh6VIj0E6evCVHTovVhNFk0NUkBVC1+w5AtrPPmCByv8z+48cAcO1cWsGu/EZ+nTc4985d\nIElMnzyhOutUiscJnj+L2ePB0pi9z+ti4FgkE6WtYwWhq1fw
DcvlBdTMQ7S/j8TUFI6NmxZsXP1O\npzQzc1gM+7JgNOLYsJH42BjTQxOAOk3Vd+RNAMr27Mtxpb4IgiAnKekYCQFgb2nB7PEQPH+OgC+C\n2WLM2es0EQwQuHgBa0srlrrFzbqdy4wTWd95cN0hb3BTx0+qzjoNXrmMGA7j3Lp9SZLU0tGz92k6\n7n37QRQZPymH86oxTwbOlcwwaigJ9jnoHcuu4Eoen729Q7Oek43Y+BihK5exr1mL2ePRdSxqcLqt\nBANREon5ZWbzRRAEnFu3I4bDBLxBVRqa/9QpSCRwLbEZBtJ8DTpr7OaqamwrVjJ1U85tUGNbDiSj\niVzbltYMA/pnZCuU7doDRiMTV+X67K6yhedBkiR8bx0GoxHHpk26jmW5URLsc9C7NoaCa+cdGBxO\npsenEYTcx07fW0cAKNu7X9dxqEV5mUM6hrmBLJhEDISjkioNzX9CNsMstX0dwGQyYrObdDfFgLwe\nIkbZb5Jrk5dEEf/p0xgcjkUvApeJmdr0+s6D0e3GuXkLfp86v1P4+jWi/X24tu/E5F6aJLVipSTY\n57BYx06DxULZnXcRESzYzCzoLJMkCd+RNxFMpluiqcLiRMaAXJ0yXlkLgMO+cMxyIhggcOE81pYW\nLHX1uo5DLU63VXeNHeSNShHsuTT2wNkzxCfGcW3fiWBa+v7zi3WKBSjbt5+ISW66nsvvNHXwFQAq\n7r1P93EsN0qCfQ56t8hLp/yee4kYnViiC/dfjHR1EhsawrV9B0aHQ/dxqGExQv0ABJMJy54DAJgm\nBhe8dvrYMdkMs8RO03ScbiuxaIJoRJ+6OQrm6hrE2mYArGSfY0mSmPj1rwCofM9Duo5BLQ6XFUEA\n/3RY9+92btlGxCr38XQ4sod8xqd9+E8cx9LQuKTZx8VKSbDPYTGSUhSkihpEgxGzf4LIQH/W6xSn\nqXvfrTHDwOKE+ikIa2THl3TzEmI08zyLkQgTv3oawWymbP9duo9BLYsV+gkgJRtbx46+kfWa0LWr\nhG/ewLltO9amJt3HoAaDQcDhshLw6T8HBrOZmKMKczxE5NrlrNf53ngdKR6n/MB9S+48LkZKgn0O\n9mSjCb1NEEDKlmiNB/AefDXjNdGhQbyvH8JYXoFzw61zEC3m8Tuc/EqzfwLf4cxCbfKlF4hPTlL5\n7gcxV1XpPga1KCeXxbCzx9zVAMRPHSHS25PxmolfPwtA1Xsf0f35WnCVWQn49auboyBJEiEs2OIB\nxn72E6T4/JORJIp4XzuIYLFQtv/WKTvFREGC/R/+4R9473vfy/vf/34+//nP4/cvbGIoBpTO7Ho7\nT2FG+7WbJbxvHCLS2zvrc0kUGXryO0ixGLUf+/gtsacqLKbGrmyaNqJMPv+bVL0QhbjPx+RvnsXo\nclP50MO6P18Li1E3R8E/HUEQwBIPMfqTH837PNLbQ/D8Wexr1t7yZhIutxVRlHR3pkfCcRIJCWeZ\njUh3FxPP/XreNVMvv0hsbBT37r0YHaXaMGooSLDfddddPPvsszz99NO0tbXxzW9+U69x3VIcLquu\njSYUFOHg2b0NKRql/2tfJj41lfp88sXnCd+4jnvXbjmJ5RaSciIvgkBTNovqbZuIjY4y+sMfzGqb\nN/7M04jhMFXve/8t8zEoLKZgD/giuMpsODdsIHjh/KwKoHHvFEPffRK49do6LF6yljJpmH0fAAAa\nGklEQVSvVWtXYKyoYPyZp4mklbj2nznN6I/+A2N5OdXve7+uz17OFCTY9+/fn4ru2LZtG0PJkqzF\njtNlIREXCQVjun6vsohrNq+j5kMfJj4xQf/XvkLg/Dkmnv8N47/4GUa3m9qPfULX5+aDEuq3GCYp\nZR4aH3svloZGpl55iYGvfYXg1Sv0f+0reF99GXNdHRX33Kv7s7WSEmg6z0MiIRLwR3G5rdR8+CMg\nCAz809cY/fF/ELx0kZ6/
/SsiXZ2U7b8zVV/mVuJMxpj7dbazKxuFu8pJ3Sc/BYkEg//yTbyHXmP6\nxDEGv/UNBLOZpj/5Auaqal2fvZzR7az/k5/8hEceufWahR64ymwA+KZCGC36uSHSC4BVvPcRosPD\n+N58nf6vfEm+QBCo/cSnMLrduj2zEBwuK36f/pEQQX8Um92EraaKlv/z/2Hwm/9E4NxZAufOAnIr\nvdqPfeKWmqIUUmGfOgs0xTnvKrNia22j/vd+n7Gf/pjJ559j8vnnAKj+4G9R9fCjt4WzcLGcyOm1\nk1ybtlF+zwG8h15j+KknU9c0/OGfYOtYoetzlzs535zHH3+csbGxeX9/4oknuP/++wH4xje+gdls\n5rHHHlP9YI/n9hBemahvLOP8yX68UyHWbtQvfjoWkW3JbR3VWG1map74Y3qb5DR5Z1srrlUrsdXn\n97zFmM/KagcTowHKy+w5U//V4vG4CQailFfak2N2U/fXf0H3975P4GYnTR98P+Vbt9xSYZY+l5Ik\nYbYYiYRius5xKOmU9tSV4fG48Tz2IB0P3sfwiy8x+tobNH3wfVTv26t6nItNJCg7NRNxUfNzF7pe\nTMjmzqaWSjweNzX/+fP4H32IYE8Pwd4+XCtX4jlwd/4D12mcxUbOt/XJJ59c8POf//znvPbaazz1\n1FOaHjw6Oq3p+qVEMMpCxTcZ0nWck+MBzBYjvukwJGOCHe95FAAJmAam83iex+NelPlUmh50dY5T\nWV24rdvjcTPQP0UkHMdqM80as/PhD+AEYsDY2K1zwmeaS4fLwtSUvmuht2cSAKNZmPW9pt1307D7\nbkQWfkcW6zfPRizp4B4dntb03FzjHB2SP4snEjPXVTVgqGrAtW2PfM0S/DuXej7zRe3mU5Cd4dCh\nQ3z729/mG9/4BhaL+qbEtzvKsdM7FdL1ewMamzffapyL0CpwOmnaUcxdxYDLbSUcjJGI61c3J6Ch\nyuftgMNpwWAQdDfF+DWUsS6hnoLO13/zN39DLBbj05/+NABbt27lL//yL/UY1y1FKUbkndRPsMfj\nCcKhONW1Lt2+c7FZjIgQxWbvLi8ewZ6ejVxWoU9dfH9qgysOgSYnKVkWJSrGZjdhNqtr7F5CHQUJ\n9hdeeEGvcdxWKCnUemrsWhpL3C4ojkM9X+Zpb1JTLRKBBmkRIdN6CnZlHopng3O5rQwP+BBFCYNB\nHx+IPKfFMwfFQinzNAMGg4DLbcWno8ZejEdOd1Lo6Bnipphi3MUk0Bahbo5/OoLJbMBqu/WRP2px\nlVmRJHRrbB2NxIlFE0VjjiomSoI9C64yG9O+sG71yFM2VZV9HW8HFG1y2qtfyGNRmmIWySTlcltv\ni1BGtSjzoFcIbDEqO8VCSbBnwVWe1E50SkyZidctHuep1WbCYjWltGw9mPZGVNWjv53Q2yQVi8n+\nlmIywwC43PJ49drgis2BXEyUBHsWUtqqTkJN0XqLSVMFKCu3Me0N61Zewe8L43RbMRqLZ+nNJOfo\nu8kXm0Bz6Zx9qrbBRgntFM/btcS4dV7ExSrYXeVW4jGRcKjw8gpiQiQwHSk6TdWuc6hfSqAVkQMZ\n0kwxemvsRTYPxUBJsGfBlXIc6qOx+7xhLFYjVlv2ZgK3I8pGpMcG5/OGkaSZTbNYUJp769Vowl+E\nDmSYEcC6bXAlG/uiURLsWdDz2ClJEtPeMGXl+oTKLSXuVN2cwoWaEj7qKrJTC8gbXGA6qkuS0kyo\nY3EJNLtDPrnodYpN+Z2KKKCgWCgJ9iwojiI9NPZwKEY8JhadGQbSNXYdBHsyfLTYNHYgFb+uh8/F\nX6Q2doNB55PLdASL1ahbHaISM5QEexasNhNWm4lpHbSTYrWvw8yY9Qh59E4GgeJKylFwV+h3cim2\nrNN0nGU2gv4ooljYyUWSJDnkswjXQjFQEuwLUF5h10VTTQn2Isyw01ewh2Z9ZzFRVj5Tyr
lQ/NMR\nrDYTZkvxaaoutz5hwJFwnGgkUco6XSRKgn0ByirtRCMJIuHCOtQrWl4xCjRZABl1MUEUsynGrZhi\nCtzgZE01UnRmGAW9fE/KWtCrREOJ2ZQE+wKUJxddoTZFRRiUFaFgFwQBV5lVN429WDXVMp1MMak0\n+iLc3CDNmV7gelBOPiWNfXEoCfYFKK9MCvYCtZNitrGDvCEVenKRJAnvVKho58DhtGA0GZj2FmaK\nmYlhL855KEu+E4XWUVI2yJLGvjiUBPsCpDT2As0QPm84lZ5fjLh0iIwJh2JFrakKgoC73Fawxp4K\ndSxSU0x5pbwWCq18OqOxlwT7YlAS7AugaCeFRMYoMezFqqmCPsdvRaAVW1JOOmXlNiLheEEnF8W2\nrJwGiw1XmQ1B0FFjL+L34namJNgXQA+NPZTsvFPMtsRULHsBgl0xRxVzeJvyGxZijil2wW40GnCX\n23TR2F1lVoymkghaDEqzugDu8qR2UsDxu9jt66BPyGOqDnt5cZogANzJzOFC1oMSy1+sgh3ksYcC\nMaKR/E4uibiI3xcpaeuLSEmwL4DRaKCswo53In/tRLElLgvBXsDJxZ/snFTM86BHZIx3MoTdaS5a\nfwvM2MXznQdlHZXs64tHSbDnoKLKTjgUy7u64XLQ2O0OczIiJH9fQzE2sZ7LzMklv40+kRCZ9oYp\nr3ToOawlRzlt5NsTuBTquPgUpDZ89atf5eWXX8ZgMFBdXc0Xv/hFPB6PXmO7LSivcsCNCbyTIWx2\n7ZUZZ2LYi1c7EQQBd4Gx7FMTQSxWE3ZHcVW3TCelqeY5D9PJ6pbFbIaBdI09T8E+mXwninwebmcK\n0tg/85nP8Mtf/pJf/OIX3HvvvXz961/Xa1y3DRVV8uKbGg/mdf+Mxl68tmWQtVUlZFEroijhnQxR\nU+sqqlZwc0nVD8rTBKGY9IpesCshjwVr7MU9D7czBQl2p9OZ+u9QKITBsPwsOxVV8rF5ajI/we7z\nhrHZzUWZbZmOklKfz8s87Q0jJiSqa525L77NcZfbknXltXeUUtaQoiwUKwVr7KnkpJIpZrEoWNp8\n+ctf5umnn8btdvPUU0/pMabbivKkYM/HgSpJEn5vmOpal97DWnIqq+V5mBwPUFOn7d+jnHZqlsE8\nlFXYGBv2EwxENdcR9y2T+ihmsxGny5J3LLtvKoTZYszLtFlCHTkF++OPP87Y2Ni8vz/xxBPcf//9\nPPHEEzzxxBN861vf4nvf+x6f//znVT3Y43FrH+0toL2jGrPFiN8X0TxmnzdEIiFRU+ta9H/vYn9/\n+4oa3uQ60VBC87OuXxgBWJJ50IOFxljXUM7NK2MYMWj+twT9sgN+5eparLbCT3C3ci6ra130dE5Q\nWenAZDIueG36OCVJwucNU1XjpLa2bLGHqYliWJtqybm6nnzySVVf9Oijj/IHf/AHqgX76Oi0qutu\nJR6Pm7ExP+UVdsZH/YyM+DTZiHs7JwCwuyyL+u/1eNyLPp8Gs/zv7uuZ1Pysvp5JAKprF3+chZJr\nLk0W2dzY0zWOzaVN4xwdnsbhtOCbDkGB07AUv/lCOJwWkODm9bHUaS4Tc8cZDESJRRM4Fvmd0Mqt\nnk+1qN18CjKKd3d3p/775ZdfZsWKFYV83W1LeZWdeEzU3OtxYjQAQLWn+G3LTpcFs8XI5HhA871T\n40EEAapqijvMD/KPZU8kRPy+cNE7ThXyLQZWcpwuDQWdB7/0pS/R2dmJwWCgsbGR//bf/pte47qt\nSDlQJ0Ka4rAnxmQhWFlT/IJdEAQqaxyMDfkRRVGTo3xyIoi73JbzyF4MKGtB6wbnm1oeoY4KqVh2\njQ7UkuN0aShIsP/jP/6jXuO4rSlXQh4ngjS3V6q+b2IsgMEgLJuXubLaycjANN7J8ILH73TCoRjh\nYIy6huVhv3SX2zBbjIyPaBPsqVICRR4Ro1Be0thva5
ZffOIiUJFHZIwkSUyOBamodmA0Lo9pTkXG\njKkXalMTyRA/lRvB7Y4gCFTXOpmaCBKPq4/pXy4x7AqKxq1VY1fWTrGHfN7uLA+Js8ikkpQ0xLL7\nfRFi0QRVy8AMo1BZo5gh1M+DEuqobI7LgSqPC0mCyTH186AIwOUi2K02M1abSXNew9iwH4vVVNQl\nNoqBkmBXgdVmxuYwa9LYFcfpcnAYKlRWy5uUFvuysgksF40dZpzhym+shuWmsYN8gvNNhojH1J1c\nYtEEUxMhauqKOwO5GCgJdpVUVNnxTYVIJERV1yuO06plEBGj4C63YTQZNGmqisau1iZfDCgJZ+Oj\nftX3eCdDOFyWos9ATsdT70aSYFzlBqfM13JIVLvdKQl2lVRUOpAk9WFuyykiRsFgEKiosjM1HlSd\nUj81EcRqMy2rLEPFvKbWgRoJx5n2qnc4Fws19bJDfHRIXfz32LAs2Ks1Zi6X0E5JsKskPTJGDROj\nAYwmw7Lz/lfWOInHRVWVHhMJEd9UmIpqx7I6elttJtxlVtWmGEXw1TbcXpmWheKplwW0WsE+PlLS\n2JeKkmBXiWJSGVOxiEVRYmo8SGW1A4Nh+Qg0SK8Zk3uD802FEEWJymXkOFWoqnURDEQJBaM5rx0Z\n9AFQu0xCPhUqqx2YTAZNGrvBIKSc8CUWj5JgV0ldo6xtDfX7cl477Q0Rj4vLKiJGIeVAVWFnV65Z\nTo5TBcWBqsYcMzwgrxllDS0XDAYD1XUuJsdyh36Kosj4aICqGueyCf+9nSnNsErsDgsVVXaGB3yI\n4sL25YlRWaAtJ8epwkzIo3qBphzZlxNqHaiSJDEyMI3TbcHpLu6a/Jnw1LkRRSnnBuedCJGIiyX7\n+hJREuwaqG8qJxZN5EzQmXGcLj9NtbzSjsEgpOylCzHQO4XBIFDXWL4EI1taUiGPOQRaYDpCMBBd\ndvZ1BbV29jHFvl4S7EtCSbBroK5ZMcd4F7wuFeq4DE0xRqOB2sYyxob9RMLZ+8DGonHGhvx46t2Y\nLcVfI2Yu5VV2jEYhZ6jfyKDiOF1e9nUFj8rIGCUipuQ4XRpKgl0D9U2y5jnUl93OLkkSw33eZZ1d\n19xWgSTBQM9U1muG+mWTVUPL8tPWQbYvV9Y4mRgLLGiaW672dYXKGtmBOja08AkuFepYEuxLQkmw\na6Cy2oHFalpQY58cDzLti9C6onJZhfiloxRC6+uazHrNYK88R40tFUsypltBtcdJIi4u6G9QNHZF\ns11uGAwGqmtdTIwFsjpQJUlibMSPu9ymS4ORErkpCXYNCIJAfVMZvqkwwUDmMLeeG+MAtKyoXsqh\nLSm1jWWYzAb6urNr7AO98mf1zctTYwdobJM3uO7r4xk/F0WJ0aFpKmtkhWC54ql3IYpS1rj+oD9K\nOBgr2deXkJJg10h9k3ykHs4S9thzU+6a1LqiasnGtNQYjQYaWyuYGg/iz9B8JB5PMDLgo6bOtaw1\ntPZV1QgCdF6d3zoS5HIKsWhi2TpOFXLZ2ZV3Yrmao25HSoJdI3WKnT2DOSYaiTPY68VT75Jbhy1j\nmpPaan8Gc8zIwDSJxPK1ryvY7GYaWysYGZzOuMEp9vXl6jhV8CT/fdlMc9cuDgOwan3tko3pnU5J\nsGukrtGNIGROVOrrmkQUJVqXsRlGoSkp2Pu657/Mg0kzzHK2ryt0rKkBoCuD1t6f7PW63DXVqhon\nVR4nXdfG55kofd4Q/d1T1DeXL9tggtuRkmDXiNliorrWxeigj3BodrhfygyzcvmaYRSqa53YHGb6\nuybnFQQbSDpOl7vGDtCxWhbsN6+Ozvq7fzrCjUujVFTZl71tWRAENmxrQBQlrpwbmvXZhdMDAKze\nUNLWlxJdBPt3vvMd1q1bx9RUdmfacmLNxjoSCYmTR2aaeUuSRM/NcWx207K3qYL8Mje3VRDwR2cV\nRgsGogz1e6msdm
B3LG9zFICrzEZtg5uBnqlZG/25432IosTWPS3LNjoqnTUb6zCaDFw6Mzhroz9/\nsh+DQWDlOs8tHN07j4IF+9DQEIcPH6axsVGP8RQFm3Y04S6zcu5Ef6rK4cRogMB0lJaOqmVX+Csb\nTcmwx7PH+1N/e/2Fa8RjIhu3v3PWQ8eaGiRpJjomEo5z8fQADqeFNRvrbvHolgarzcyqdR68k7Lp\nBeTQ38E+Ly0dle+ITf52omDB/nd/93f82Z/9mR5jKRqMJgO77ulATEgcfb0T72SIF56+CEB78mj+\nTmD1+lqqPE4unhrg3Ik+blwe4eaVUeqby9m0s+lWD2/J6Fgja6PnT8kb/cUzA0QjCTbf0YTJtPyy\nbrOxYZu8mV86M5A0ywwCsGrDO2Nzu50oKBbtlVdeoaGhgbVr1+o1nqJhzca6/7+9u4tpMkvjAP6v\ntIDDOKaK06DD6CwOG4gFRhPdgURtbeSjVlFRboymDUZvrCB+hKJGA8aAqJekxAjRZDTK2myI0Wym\nWiEIIsYFN6Q6bHAcjAVRMhSj9OvZC9dO2NJqzOgp5fndnSYn+acfT09P3/c56Or4DY/+PYBfe19g\n7I0H6Uu/mVI/OWXRUuQVKPH3c/fQ+nMvZNFSREmnQZX31ymx/fCOfPYXSPxOjt/6hvGT+Q6ipNMg\ni46aUr9aAEAx7yvI47/Af+zP8fiXFng8Psiio/Dd95F/MUG4eW9h1+v1GBoK/Me/uLgYZrMZZ8+e\n9T/2oafqRAKJRIK/rfwLrl56ALfLixU5yf4Vy1QyY2Yscjcq8Y+f/gXXmAc/qpIi6uDqD5W3KQ29\nPQPobP0Vvw+/RvrSRMTERs6pUR9CIpFg8Y/z0fLPX/DVzFjMmhOHH5Z+G1HHAU4WEvrIavzo0SPo\n9XrExsa+7Y8yMACFQoHLly9j9mz+hmaMMVE+urD/P7VaDYvFgpkzI/8SN8YYC2d/2nXsEolkSm3F\nMMZYuPrTVuyMMcbCA995yhhjEYYLO2OMRRgu7IwxFmGEFXa73Y7CwkLk5+ejoKAADx48EBXlvc6f\nP4+cnBzodDrU1NSIjhNUuPfsqa6uRm5uLtatW4ddu3ZhdPT9B2J/Ts3NzcjJyUF2djbq6upEx5mQ\nw+HA1q1bkZeXB51Oh3PnzomOFJTP58P69euxc+dO0VGCcjqdMBqNyM3NhVarRVdXl+hIE2poaMCa\nNWug0+lQWloKl2vig378SBCDwUAtLS1ERGSz2WjLli2iooTU3t5Oer2e3G43ERG9ePFCcKKJPXv2\njAwGA6lUKhoeHhYdZ0Ktra3k9XqJiOjEiRNUU1MjONEfvF4vaTQa6u/vJ5fLRWvXrqXe3l7RsQIM\nDg5ST08PERGNjo7S6tWrwzInEVF9fT2VlpbSjh07REcJ6sCBA9TY2EhERG63m5xOp+BEgRwOB6nV\nahobGyMiot27d5PFYgk5R9iKXSKRwOl8e+KK0+mEQhGe/SQuXLiA7du3Qyp9e/fcrFnh2ZJ3MvTs\nyczMxLRpb99yGRkZcDgc75nx+XR3d2P+/PmYN28eZDIZtFotrFar6FgB5syZg5SUFABAXFwckpKS\nMDg4KDhVIIfDgVu3bmHTpk2iowQ1OjqKzs5ObNy4EQAglUrx5Zfh2WLZ5/Ph9evX8Hg8ePPmDb7+\nOnQbZGH3+paVlaGoqAhVVVUgIly8eFFUlJAeP36Mzs5OnD59GjExMdi/fz+USqXoWONMxp49jY2N\n0Gq1omP4DQwMICEhwT9WKBRhvT0IAP39/bDb7UhLSxMdJcC7hca7xVs46u/vh1wuR1lZGex2OxYt\nWoTy8nLExobXgSAKhQJ6vR4rV67E9OnTkZWVhczMzJBzPmlhD9ZnpqSkBLdv30Z5eTk0Gg2uX78O\nk8mE+vr6TxknqFD9cLxeL0ZGRnDp0iV0d3ejuLhYyEpusvTsCfWaq9VqAEBtbS1k
Mhl0Ot3njheU\nyOfsY7x69QpGoxEmkwlxcXGi44xjs9kQHx+PlJQU3LlzR3ScoDweD3p6enD48GEolUocO3YMdXV1\nMBqNoqONMzIyAqvVips3b2LGjBkwGo1oamoK/fn55BtEQSxZsmTcePHixYKShFZUVEQdHR3+sUaj\noZcvXwpMNN7Dhw8pMzOT1Go1qVQqSk1NJZVKRUNDQ6KjTejKlStUWFjo3y8MF/fv3yeDweAfm81m\nMpvNAhMF53a7yWAwUENDg+goEzp58iStWLGC1Go1ZWVlUUZGBu3bt090rADPnz8ntVrtH9+9ezcs\n/w+4du0alZeX+8cWi4WOHj0aco6wPXaFQoGOjg4AQFtbGxYsWCAqSkgajQZtbW0AgL6+Png8Hsjl\ncsGp/pCcnIzW1lZYrVbcuHEDCoUCFoslLBuxNTc348yZM6itrUV0dHgdvKBUKvHkyRM8ffoULpcL\nV69exapVq0THmpDJZMLChQuxbds20VEmtGfPHthsNlitVpw6dQrLli1DdXW16FgB4uPjkZCQgL6+\nPgBAe3s7kpKSBKcKNHfuXHR1dWFsbAxE9EE5he2xV1RUoLKyEj6fDzExMaioqBAVJaQNGzbAZDJB\np9NBJpOhqqpKdKSQwrlnT2VlJdxuNwwGAwAgPT0dR44cERvqf6KionDo0CEYDAYQEQoKCsLyQ37v\n3j00NTUhOTkZ+fn5kEgkKCkpwfLly0VHm5QOHjyIvXv3wuPxIDExEcePHxcdKUBaWhqys7ORn58P\nqVSK1NRUbN68OeQc7hXDGGMRhu88ZYyxCMOFnTHGIgwXdsYYizBc2BljLMJwYWeMsQjDhZ0xxiIM\nF3bGGIswXNgZYyzC/Be68EGj7hfMcwAAAABJRU5ErkJggg==\n", "text/plain": [ - "[]" + "\u003cmatplotlib.figure.Figure at 0x7f385e198650\u003e" ] }, - "execution_count": 48, "metadata": { "tags": [] }, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "# Create TensorFlow Variables using Keras's Dense layer.\n", + "def f(x):\n", + " return tf.square(tf.sin(x))\n", "\n", - "wb = tf.keras.layers.Dense(units=1, use_bias=True)\n", + "def grad(f):\n", + " return lambda x: tfe.gradients_function(f)(x)[0]\n", "\n", - "# We can access the underlying TensorFlow variables using wb.variables.\n", - "# However, the variables won't exist until the dimensions of the input\n", - "# tensors are known. Once the dimensions of the input tensors are known,\n", - "# Keras can create and initialize the variables. 
Until then, Keras will\n", - "# report the variables as an empty list: [].\n", + "x = tf.lin_space(-2*pi, 2*pi, 100) # 100 points between -2π and +2π\n", "\n", - "wb.variables" + "import matplotlib.pyplot as plt\n", + "\n", + "plt.plot(x, f(x), label=\"f\")\n", + "plt.plot(x, grad(f)(x), label=\"first derivative\")\n", + "plt.plot(x, grad(grad(f))(x), label=\"second derivative\")\n", + "plt.plot(x, grad(grad(grad(f)))(x), label=\"third derivative\")\n", + "plt.legend()\n", + "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "docKLUaonYG_" + "id": "-39gouo7mtgu" }, "source": [ - "## Step 3: *Define the loss function*\n", + "## Gradient tapes\n", "\n", - "Our loss function is the standard L2 loss (where we reduce the loss to its mean across its inputs)." + "Every differentiable TensorFlow operation has an associated gradient function. For example, the gradient function of `tf.square(x)` would be a function that returns `2.0 * x`. To compute the gradient of a user-defined function (like `f(x)` in the example above), TensorFlow first \"records\" all the operations applied to compute the output of the function. We call this record a \"tape\". 
It then uses that tape and the gradient functions associated with each primitive operation to compute the gradients of the user-defined function using [reverse mode differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation).\n", + "\n", + "Since operations are recorded as they are executed, Python control flow (using `if`s and `while`s for example) is naturally handled:\n", + "\n" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, @@ -245,125 +182,42 @@ } }, "colab_type": "code", - "id": "0_w8ZJSCtuY7" + "id": "MH0UfjympWf7" }, "outputs": [], "source": [ - "def loss_fn(predictions, labels):\n", - " \"\"\"Calculates the mean L2 loss for our linear model.\"\"\"\n", - " return tf.reduce_mean(tf.square(predictions - labels))" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - "height": 34 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 348, - "status": "ok", - "timestamp": 1525154234538, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "RkNbXoXkpjVH", - "outputId": "e4688f3c-e29f-416d-f541-6d81953b5660" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\u003ctf.Tensor: id=1252, shape=(), dtype=float32, numpy=16.979801\u003e" - ] - }, - "execution_count": 50, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "# Test loss function (optional).\n", + "def f(x, y):\n", + " output = 1\n", + " for i in range(y):\n", + " output = tf.multiply(output, x)\n", + " return output\n", "\n", - "loss_fn(wb(inputs), labels)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - 
"height": 51 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 418, - "status": "ok", - "timestamp": 1525154260083, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "K_7beXoHOU7t", - "outputId": "8f55c028-fe2b-4edb-ad68-a849afc60623" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "w: -0.311619\n", - "b: 0.000000\n" - ] - } - ], - "source": [ - "# At this point, the variables exist, and can now be queried:\n", + "def g(x, y):\n", + " # Return the gradient of `f` with respect to its first parameter\n", + " return tfe.gradients_function(f)(x, y)[0]\n", "\n", - "w, b = wb.variables\n", - "print(\"w: %f\" % w.numpy())\n", - "print(\"b: %f\" % b.numpy())" + "assert f(3.0, 2).numpy() == 9.0 # f(x, 2) is essentially x * x\n", + "assert g(3.0, 2).numpy() == 6.0 # And its gradient will be 2 * x\n", + "assert f(4.0, 3).numpy() == 64.0 # f(x, 3) is essentially x * x * x\n", + "assert g(4.0, 3).numpy() == 48.0 # And its gradient will be 3 * x * x" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "JVDWpL9VYWdP" + "id": "aNmR5-jhpX2t" }, "source": [ - "## Step 4: Create an optimizer\n", + "At times it may be inconvenient to encapsulate the computation of interest in a function, for example when you want the gradient of the output with respect to intermediate values computed in the function. In such cases, the slightly more verbose but explicit [tf.GradientTape](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context is useful. All computation inside the context of a `tf.GradientTape` is \"recorded\".\n", "\n", - "We'll use a `GradientDescentOptimizer` to fit our model." 
+ "For example:" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, @@ -371,36 +225,48 @@ } }, "colab_type": "code", - "id": "DudNEebMKDWN" + "id": "bAFeIE8EuVIq" }, "outputs": [], "source": [ - "optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)" + "x = tf.ones((2, 2))\n", + " \n", + "# TODO(b/78880779): Remove the 'persistent=True' argument and use\n", + "# a single t.gradient() call when the bug is resolved.\n", + "with tf.GradientTape(persistent=True) as t:\n", + " # TODO(ashankar): Explain with \"watch\" argument better?\n", + " t.watch(x)\n", + " y = tf.reduce_sum(x)\n", + " z = tf.multiply(y, y)\n", + "\n", + "# Use the same tape to compute the derivative of z with respect to the\n", + "# intermediate value y.\n", + "dz_dy = t.gradient(z, y)\n", + "assert dz_dy.numpy() == 8.0\n", + "\n", + "# Derivative of z with respect to the original input tensor x\n", + "dz_dx = t.gradient(z, x)\n", + "for i in [0, 1]:\n", + " for j in [0, 1]:\n", + " assert dz_dx[i][j].numpy() == 8.0" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "YBeJYxY8YaiO" + "id": "DK05KXrAAld3" }, "source": [ - "### Step 5: Define a training step\n", - "\n", - "To fit model variables to the data we'll need to:\n", + "### Higher-order gradients\n", "\n", - "1. Calculate the gradients of the loss with respect to the model variables.\n", - "2. Use `optimizer` to compute updates to the variable values based on those gradients.\n", - "\n", - "To calculate the gradients, we use the [`tf.GradientTape`](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context manager\n", - "and its `gradient` function to compute gradients through computation conducted within its context:\n" + "Operations inside of the `GradientTape` context manager are recorded for automatic differentiation. If gradients are computed in that context, then the gradient computation is recorded as well. 
As a result, the exact same API works for higher-order gradients as well. For example:" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, @@ -408,163 +274,37 @@ } }, "colab_type": "code", - "id": "diDZfrMJM3OC" + "id": "cPQgthZ7ugRJ" }, "outputs": [], "source": [ - "def run_step(inputs, labels):\n", - " with tf.GradientTape() as g:\n", - " loss = loss_fn(wb(inputs), labels)\n", - " # Compute the partial derivatives of loss with respect to the variables\n", - " grads = g.gradient(loss, wb.variables)\n", - " optimizer.apply_gradients(zip(grads, wb.variables))\n", - " return loss" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "1WWepgmJQOzc" - }, - "source": [ - "Repeatedly running the training step will nudge the variables towards the values that best fit the data (i.e., \"w\" will move closer to 3.0, while \"b\" will tend to 2.0):\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - "height": 51 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 380, - "status": "ok", - "timestamp": 1525154412590, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "ya5Qxz5XQlhU", - "outputId": "8dd47155-a6c1-44c5-c279-617c803f1723" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Values of w, b BEFORE applying gradients: 2.725763, 1.894334\n", - "Values of w, b AFTER applying gradients: 2.774932, 1.922555\n" - ] - } - ], - "source": [ - "w, b = wb.variables\n", - "print(\"Values of w, b BEFORE applying gradients: %f, %f\" % (w.numpy(), b.numpy()))\n", - "run_step(inputs, labels)\n", - "print(\"Values of w, b AFTER applying gradients: %f, %f\" % (w.numpy(), b.numpy()))\n" + "# TODO(ashankar): Should we use the 
persistent tape here instead? Follow up on Tom and Alex's discussion\n", + "\n", + "x = tf.constant(1.0) # Convert the Python 1.0 to a Tensor object\n", + "\n", + "with tf.GradientTape() as t:\n", + " with tf.GradientTape() as t2:\n", + " t2.watch(x)\n", + " y = x * x * x\n", + " # Compute the gradient inside the 't' context manager\n", + " # which means the gradient computation is differentiable as well.\n", + " dy_dx = t2.gradient(y, x)\n", + "d2y_dx2 = t.gradient(dy_dx, x)\n", + "\n", + "assert dy_dx.numpy() == 3.0\n", + "assert d2y_dx2.numpy() == 6.0" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "61TgeLVlKEQp" - }, - "source": [ - "## Step 6: Create a training loop\n", - "\n", - "Of course, now we can simply turn all of this code into a self-standing training loop. We'll also capture our loss and approximations of `w` and `b` and plot them over time." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - "height": 364 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 580, - "status": "ok", - "timestamp": 1525154278709, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "VukGe-huNaJ4", - "outputId": "c79c8e63-c781-451e-f74f-20815d8da49f" + "id": "4U1KKzUpNl58" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0.9409681558609009, 1.3733772039413452, 1.7128530740737915, 1.9793939590454102, 2.188689708709717, 2.3530514240264893, 2.4821391105651855, 2.583533763885498, 2.6631851196289062, 2.7257626056671143]\n" - ] - }, - { - "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAd8AAAFKCAYAAABcq1WoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzs3Xd4U2X/BvD7ZLRpumlLS6EDgbKh\niIggU7aAgPhDRKsIUoYgiK++ioAguBARXmZBEARFUBGhiChIEQcqe+/RMlpGd9KRcX5/nDZtaFra\nkuY07f25rlw5zXmSfPMk5OY5Oec8giiKIoiIiMhhFHIXQEREVN0wfImIiByM4UtERORgDF8iIiIH\nY/gSERE5GMOXiIjIwVSOeJJbtzLs/pi+vlqkpOjt/rhkjf3sGOxnx2A/Owb7WRIQ4FnsOqcd+apU\nSrlLqBbYz47BfnYM9rNjsJ/vzWnDl4iIyFkxfImIiByM4UtERORgDF8iIiIHY/gSERE5GMOXiIjI\nwRi+REREDsbwJSIih/vxx61YtGi+3GXIhuFLRETkYA45vSQREZEtGzeux65dPwMAOnbsjOeeG45/\n/tmHFSuWwNVVA1/fGnjnndk4eHB/kdtUKueNMKesPDZ2C7p37wSNxkfuUoiInN6MGVOxdetmuz2e\nQiGgb98BmDFjdontbty4hgMH/sGKFV8AAKKjX0DXrt3x3XcbMH78q2jZshX27PkVaWmpNm/z8/O3\nW82O5nSbnTMy0jFixHN48cUX5S6FiIjuw9mzZ9G0aXOoVCqoVCo0b94S58+fRdeu3fHxxx/giy9W\noUGDhvDz87d5mzNzupGvp6cXOnTohF27duHkyRNo0qSp3CURETm1GTNm33OUWhYBAZ6lms1OEABR\nFC1/GwwGCIICvXv3Rdu27fDbb3H4739fxezZc2zeFhYWbreaHc3pRr4AEB09DgCwYsVSmSshIqLy\niohoiOPHj8FoNMJoNOLkyROIiGiI1as/g1KpwoABT6Jbt564fPmizducmdONfAGgR49eqFevHr79\ndgPefnsG/P2de/MDEVF1FBQUjFatHsKECdEwm0X07z8AQUG1EBgYhEmTxsHT0wuenp4YOvQ56PX6\nIrc5M0EsPOavIKXZ/FBW69d/jokTJ+LNN6di8uQ37P74JCnt5iO6P+xnx2A/Owb7WRIQ4FnsOqfc\n7AwAL774Ijw9vbBq1Qrk5ubKXQ4REVGpOW34enp6YtiwKNy8mYQfftgkdzlERESl5rThCwAvvTQa\nCoUCMTFL4ICt50RERHbh1OEbFhaO3r374ujRw/j7731yl0NERFQqTh2+ADB6tHTY0fLlS2SuhIiI\nqHScPnwfeaQ9mjdviR9/3Ir4+Ctyl0NERHRPTh++giAgOnoszGYzVq5cLnc5REQkk/Pnz1kGYe+8\n8xZycrLL/ViHDx9ESkqyvUorwunDFwAGDhyMgICa+PLLL5CZmSl3OUREJIM9e35FQkI8AGDmzA/g\n6qop92Nt27alQsPXKc9wdTdXV1e8+OJLmDPnfWzY8BVGjoyWuyQiIrqHYcMGY+3ajRBFEX36PIaF\nC5ehUaMmmDx5PN54420EBdWCyWTCnDnv4fr1azAajXjppTFo3boNtm+PxaZNG6FSqVG/fgQGDhyM\nH37YhD17foWvry+mT38LX3yxAZ9+Oge+vr44c+Y0UlNT8OyzL2Dbtq1IS0vFokXLIQjAzJlTkZWV\nhezsbLz66uvQ6TKxd28cLl26iNmz5+DMmZP4+ut1UCpVaNiwMSZMePW+X3uVCF8AeOGFkZg/fy5W\nrFiKF198CQpFlRjUExFVOPcZU+FqxykFoRDg3ncAdPeYrKFhw8a4ePECjEYDGjVqjOPHjyIiohGS\nk5MRFFQLAPDLLz/Bz88fb701HampqZg4cQzWrPkaX3+9DnPmz
EdgYBC2bduCOnXqoG3bdujSpRua\nNGlm9TxKpQoLFizFzJlTcezYUSxYsASzZk3DwYP7ER5eF/36DUSnTl1w4MC/+PLLNXjvvY9Rv34E\nJk9+A15eXlizZiWWLfscLi4umDbtTRw9ehgtWkTeVxdVmfANCAjA4MFDsH79Ouza9TN69Ogtd0lE\nRFSCyMgHceLEMeTm5uCpp57Gnj270bLleURENLS0OX78KI4cOYSjRw8DAHJycmAwGNC9ey9MmfI6\nevXqg+7de5W4iblxY2n2Oz8/f8tMSL6+ftDpMlGjhh/WrPkM69evhcFggEZj/TiXLl1EUlIiJk8e\nDwDQ6TKRmJiIFi3u77VXmfAFgFGjxmL9+nWIiVnK8CUiKiXdjNn3HKWWRUCAJ3SlOLdzq1atsW7d\nauTkZKNfvwHYtm0rjh07ggcffMjSRqVS4/nnRxT5To+KehE9evRBXNxOvPLKWCxeXPwOt0ql0uay\nKIrYuPEr+PvXxLRps3D69EksWjTf6r5qtbSped68Rfd8PWVRpbbNNmvWHB06dMJvv+3GqVMn5S6H\niIhKEBoahqSkJGRm6qDVusPPzw9798ZZhW+TJs3w++97AAApKcmIiVkMs9mMmJjF8Pf3x9Chz6FZ\ns+ZITEyEIAgwmUxlqiEtLRW1a9cBAOzZsxtGoxEAoFAoYDKZEBoajsuXL1l2vlq5Mga3bt2879de\nqvA9e/YsunfvjnXr1gEAbty4gaioKAwbNgwTJ06sVBMbcK5fIiLn4evri6CgIABS0N64cQM1awZa\n1j/2WHe4uWkxZswIvPHGq2jRIhIKhQJarTtGj34REyeOhSAIaNAgAi1btsL8+R9j//5/Sv38vXv3\nxYYNX+LVV19G06bNcOfOHWzbtgWRkQ9i6tT/4vr1a5g48TX85z8TMXbsCKSlpcLfP+C+X/c9pxTU\n6/UYPXo0wsPD0bBhQzz33HN466230KlTJ/Tp0wfz5s1DUFAQhg0bVuxjVMTUUsVNWWUymdCu3YO4\nceM6Dh06xbl+7xOnBnMM9rNjsJ8dg/0sua8pBV1cXLBixQrUrFnTctvff/+Nbt26AQC6du2Kv/76\nyw5l2odSqcSoUWOQk5ODtWs/l7scIiKiIu4ZviqVqsjeX1lZWXBxcQEA+Pn54datWxVTXTk988xz\nnOuXiIgqrfve27k0U/n5+mqhUinv2a6sihvSBwR44qWXRuLTTz9FXNxPePbZZ+3+3NVJSZtOyH7Y\nz47BfnYM9nPJyhW+Wq0W2dnZ0Gg0SEpKstokbUtKir5cxZXkXr8pDBv2IhYsWIC5cz9Bjx79IQiC\n3WuoDvjbjWOwnx2D/ewY7GfJff3ma0v79u2xY8cOAMDPP/+Mjh07lq+yChQWFo5evR7H4cOH8M8/\nf8tdDhERkcU9w/f48eOIiorC999/jy+++AJRUVEYP348Nm/ejGHDhiE1NRUDBw50RK1lxrl+iYio\nMrrnoUb24MhDjQoTRRHdunXEyZPH8e+/RxESEmr3Oqo6bj5yDPazY7CfHcPe/RwXtwtdunSz2+M5\nit03OzsLzvVLROTcbty4jp07d8hdht1V6fAFgEGDnoK/fwDWrVvDuX6JiCqRYcMGw2QywWg0okeP\nTjh9Wjot8OTJ45GYeAMAMG/eRzh8+CA+/3wFVq6MwaxZ0zFu3EvYv/8fTJ36huWx+vaVRsaXLl3E\nK6+MwcSJY/HWW68hI6Nybumo8uGbP9dvenoaNmz4Su5yiIgqpRqtm9m8aAptNfQcN8pmG8/o4ZY2\nmrWrgfDwUj1n/pSC586dsUwpaDabraYUfOaZKERGPogXXxwFADAaDViy5LNip42dP/9jvP76FCxY\nsBRt2jyCTZs2lqs/KlqVD19AmutXOlPXUpjNZrnLISIiFEwpeOzYETz11NM4efIELlywnlLwbvnT\nAxbn5MkT+Oij2Rg/Pho7d
vxomRChsqlSUwoWp2bNmnjyyf/D119/ybl+iYhsSD5w/J5tMpasuGeb\n7Kjh8Jw8AbDTlIJ3U6vVAFDk3A35sxFpNBosXBhT6c/tUC1GvoA01y8AxMRwtiMiosqgNFMK5k/t\ndzd3d3fcuXMbAHD+/Dno9dLJnOrXb4B9+/4EAOzcuaNMMxw5UrUJ3+bNW+DRRztyrl8iokrkXlMK\nhoXVxZkzp/G//31idb/69SOg0bhhzJgR2LHjRwQFBQMAJk78D9au/Rzjx0fjxx9jS9yELacqfZzv\n3bZv34YXXngGzz33AubNW2j3mqoiHhfpGOxnx2A/Owb7WVJtj/O9W8+evREWFo5vvvkat2/flrsc\nIiKqpqpV+HKuXyIiqgyqVfgC0ly/Hh6enOuXiIhkU+3C19PTC88+G4WkpERs2fK93OUQEVE1VO3C\nFwBGjhwNQRCwfPkSOGB/MyIiIivVMnzDw+uid+++OHz4EP79t3IeA0ZERFVXtQxfgHP9EhHJ6ccf\nt2LRovl2eSydLhP//LMPALB27WocP3603I+VmJiIkyfvfbav+1Vtw7ddu0fRrFkLxMb+gISEeLnL\nISKicjpz5rQlfKOihqNZsxblfqyDB//FqVMn7FVasarFuZ1tyZ/r95VXxmLVqhV4551ZcpdERFSt\n3LhxDf/5zyu4eTMJQ4YMQ79+A6zWf/fdRuzc+RMEQYGOHbvgmWeew9mzp/HJJx9BrVbDxcUFM2d+\ngHnz5kCv1yEkJBTHjx9Fly7dkJaWisOHDyI1NRWXLl1EdPRY7Ny5A5cvX8L06bPRtGkzLFw4DydP\nnkBubi4GDhyMDh06Y9Wq5VCpVAgMDELt2iH49NM5EAQBWq0WU6bMgKdn8SfOKItqG76ANNfvu+9O\nx7p1a/Daa/+Fh4eH3CURETncjBmu2LrVfnGgUAB9+7pixoycEtslJMRj1aovodNlYvjwYejb9wnL\nhAjXr19DXNwuLFmyEgAwduxIdO3aHT/+uBWDBj2F3r374sCBf5GcfAfDhkXh4sULGDDgSatNzgkJ\n8Viy5DNs3boZ69atxqpVX2L79q3YuXMH6tdvgKCgYEyYMBk5OdkYMmQg+vcfiD59+sHHxwcdOnTG\nxIlj8frrUxASEopNm77Bpk0b8cILI+3SR9U6fPPn+v344w+wceN6jBgxSu6SiIiqjRYtIqFSqeDt\n7QN3d3ekpaXBx8cHAHDq1AlcvZqACRNGAwD0eh0SE6+jQ4fOmDv3QyQkxKNbtx4ICwvHiRPHbD5+\no0ZNIAgC/Pz8Ua9eAyiVSvj6+kGnOwJXV1ekp6dhzJgRUKlUSE1NKXL//OkJAcBgMKBx4yZ2e+3V\nOnwBaa7fBQs+wYoVSzF8+MhiJ2gmIqqqZszIuecotSykczuX5vGsp/0rPAugSqVGu3aP4o033i5y\nr88++wJ//rkXs2fPwPjxk4p9dKVSaXNZFEUcOnQABw/ux6JF0mbmHj06Frl/RU5PWO2TJn+u3wsX\nzuPXX3+RuxwiomrjxImjMJlMSElJQVZWFry8vC3rGjZsjIMHDyA7OxuiKGL+/LnIycnGd99tQHp6\nGnr27IOnnx6Gs2dPQxAEm9MOliQtLRU1awZCpVLh99/3wGQyw2AwWE1hWJHTE1b7kS8gzfX79ddf\nIiZmCbp37yV3OURE1UJoaDimTXsT164lIDp6nNUIMygoCEOGPIOXXx4FhUKBTp26wNVVg9q1QzBt\n2pvw8PCAWq3GlCnvIDU1BcuWLURAQM1SP/dDD7XFl1+uwfjx0ejYsTPat++AuXM/QPfuPTF79gz4\n+Phi4sT/YM6c9/Dll2vg4uKKGTNm2+21V6spBUsyaFBf/PHHXvz2299o1Kix3R7X2XFqMMdgPzsG\n+9kx2M8STilYCtHR0kk3VqxYKnMlRERU1TF88xSe6/fOnTtyl0NERFUYwzdP/ly/2dnZnOu
XiIgq\nFMO3EM71S0REjsDwLSR/rt/ExBvYunWz3OUQEVEVxfC9S/5cvzExiznXLxERVQiG71041y8RUcUr\nzZSCu3fvdFA1jsfwtYFz/RIRyW/dujVyl1BhGL42cK5fIqKKlz+l4PPPP43Y2B+s1n311Rc4f/4s\npkx5HQcP7scbb0zC+PHROH36FPr27WZpN3XqGzh4cD/0eh2mTn0DEyeOxfjx0Th//pyjX06ZMHxt\nyJ/r12w2Y9WqFXKXQ0RU4Vq3drd5WblSbWkzbpzGZpvoaI2lzdq1aoSHl+45ExLi8eGH87BwYQxW\nroyx2s9m2LDn4eHhgfff/xgAcOHCecybt6jYMxBu3Lgebdu2x4IFS/Haa29i0aJPy94JDsTwLcbA\ngYPh7x+AdevWIDMzU+5yiIiqHFtTChanfv0GcHFxKXb9sWNHsXnzdxg/PhqffPIhdLrK/b3NiRWK\nodFoMHz4SMyd+yHn+iWiKu/AAd092yxZkn3PNlFRBkyerMGtW6V51uKnFLybWq22ebvRaMxbr8Kr\nr76OZs1alOaJZceRbwleeGEkXFxcsGLFUpjNZrnLISKqUkqaUhAAzGbbh3sKgoDs7GxkZ2fj7Nkz\nAIAmTZrht9/iAACXLl3E11+vq9Da7xfDtwSBgYEYNOgpzvVLRFQB8qcUnDRpbJEpBQEgIqIhRo16\nvsj9Bg58CtHRL+D992eiYUPpN+Cnnnoa164lYNy4l/DRR7MRGfmgQ15DeXFKwXs4duwIunXriM6d\nu+Kbb3649x2qGE4N5hjsZ8dgPzsG+1nCKQXvQ/PmLdG+fQfs2bMbp0+fkrscIiKqAhi+pcC5fomI\nyJ4YvqXQq1cfhIZyrl8iIrIPhm8pSHP9jkZ2djbWrVstdzlEROTkGL6lNGxYFDw8PLFy5XIYDAa5\nyyEiIifG8C0lT08vDBv2HOf6JSKi+8bwLQPO9UtERPbA8C2DunUfQK9ej+PQoYPYv59z/RIRUfmU\nK3x1Oh3Gjx+PqKgoDB06FHv37rV3XZVWwVy/POyIiIjKp1zh+/3336Nu3bpYu3YtFixYgPfee8/e\ndVVa7dt3QNOmzREb+wOuXk2QuxwiInJC5QpfX19fpKamAgDS09Ph6+tr16IqM0EQMHr0OJhMJs71\nS0RE5VLuczuPHDkS8fHxSE9PR0xMDCIjI4ttazSaoFIpy11kZZOdnY2wsDDk5ubi6tWrcHd3l7sk\nIiJyIuWaz/eHH35AcHAwVq5cidOnT2PKlCnYtGlTse1TUvTlLrA4cp+4+/nnR2Du3A+xePFyvPji\nS7LVUdHk7ufqgv3sGOxnx2A/S+w+scLBgwfRoUMHAECjRo1w8+ZNmEym8lXnpDjXLxERlVe5wjcs\nLAxHjhwBAFy7dg3u7u5QKqvOZuXSyJ/r9/z5c9i9e6fc5RARkRMpV/g+/fTTuHbtGp577jm89tpr\nmDFjhp3Lcg7R0WMBADExS2SuhIiInEm5fvN1d3fHggUL7F2L08mf6zcu7lecPn0KjRo1lrskIiJy\nAjzD1X0qmOt3mcyVEBGRs2D43qeCuX7XIzmZc/0SEdG9MXzvU+G5fteuXS13OURE5AQYvnbAuX6J\niKgsGL52wLl+iYioLBi+dsK5fomIqLQYvnbCuX6JiKi0GL52xLl+iYioNBi+dsS5fomIqDQYvnbE\nuX6JiKg0GL52NnDgYPj7B2Dt2tXQ6XRyl0NERJUQw9fONBoNhg8fibS0VGzcuF7ucoiIqBJi+FYA\nzvVLREQlYfhWgMDAQAwcOBjnz59DXNwuucshIqJKhuFbQTjXLxERFYfhW0FatIhEu3aPYvfuXThz\n5rTc5RARUSXC8K1AnOuXiIhsYfhWoN69H0doaBjn+iUiIisM3wqkVCrx0kujkZWVhXXr1shdDhER\nVRIM3wo2bFgU3N09ONcvERFZMHwrmJeXN4YNew43blx
HbOwPcpdDRESVAMPXAV56aQwEQcAHH8xC\nZmam3OUQEZHMGL4OULfuA3j55Ym4fPkSpk79r9zlEBGRzBi+DvLmm1PRokUkvvpqLbZu3Sx3OURE\nJCOGr4O4uLhg2bKVcHNzw+TJr+Datatyl0RERDJh+DpQ/foNMGvWh0hLS8X48aNhMpnkLomIiGTA\n8HWwqKjh6NOnH/74Yy8WL/6f3OUQEZEMGL4OJggC5s1biMDAIHz44SwcPnxQ7pKIiMjBGL4y8PPz\nw6JFMTAajRgzZiR0Op3cJRERkQMxfGXSuXNXjBv3Ci5evIBp096UuxwiInIghq+M3nprGpo1a4F1\n69Zg61ae/YqIqLpg+MrI1dXVcvjRa69NwPXr1+QuiYiIHIDhK7OIiIaYOfN9pKamYsKEMTCbzXKX\nREREFYzhWwm88MII9O79OPbu3YMlSxbKXQ4REVUwhm8lIB1+tAg1awbigw/exdGjh+UuiYiIKhDD\nt5Lw9/fHwoXLYDAYePgREVEVx/CtRLp27YbRo1/G+fPnMH36FLnLISKiCsLwrWSmTp2Bpk2bY+3a\nz/Hjj7Fyl0NERBWA4VvJ5B9+pNFoMHnyeCQm3pC7JCIisjOGbyXUsGEjzJjxHpKTk/Hyy6N5+BER\nURXD8K2kXnzxJfTs2Rt798Zh2bLFcpdDRER2xPCtpARBwKefLkZAQE28994MHDt2RO6SiIjIThi+\nlVhAQAAWLlxqOfxIr9fLXRIREdkBw7eSe+yxHoiOHotz587inXfelrscIiKyA4avE5g6dSYaN26K\nNWtW4qeffpS7HCIiuk/lDt8tW7bgiSeewJNPPom4uDg7lkR302g0WLZsJVxdXfHqqy8jKSlR7pKI\niOg+lCt8U1JSsHjxYnz11VdYtmwZdu3aZe+66C6NGzfBjBmzcefOHc5+RETk5MoVvn/99RfatWsH\nDw8P1KxZE7NmzbJ3XWTDiBHR6N69J+LifsXy5UvkLoeIiMqpXOF79epVZGdnY8yYMRg2bBj++usv\ne9dFNgiCgAULlsLfPwCzZ8/AsWNH5S6JiIjKQRBFUSzrnZYvX46DBw9i0aJFuH79Op5//nns3r0b\ngiDYbG80mqBSKe+7WJJs374djz/+OBo3boz9+/dDq9XKXRIREZWBqjx38vPzQ6tWraBSqRAaGgp3\nd3ckJyfDz8/PZvuUFPsfnxoQ4IlbtzLs/rjO4KGHOuCll0bjs89iMH78RHz00bwKe67q3M+OxH52\nDPazY7CfJQEBnsWuK9dm5w4dOmDfvn0wm81ISUmBXq+Hr69vuQuksps+fRYaN26Czz//DD//vF3u\ncoiIqAzKFb6BgYHo1asXhgwZglGjRmHq1KlQKHjIsCNpNBosXSodfjRx4jgkJSXJXRIREZVSuX7z\nLauK2PzAzRqSFSuW4u23/4uuXbth/frv7P6fIPazY7CfHYP97BjsZ4ndNztT5fHSS2Pw2GPdsXv3\nLnz22TK5yyEiolJg+Dq5gsOP/PHuu9Nx4sRxuUsiIqJ7YPhWAYGBgZg/fzFyc3MxduxIZGVlyV0S\nERGVgOFbRfTs2QcjRozC6dOn8O670+Quh4iISsDwrULeeWc2GjZshJUrl2Pnzh1yl0NERMVg+FYh\nbm5uWLZsFVxcXPDKK+Nw8+ZNuUsiIiIbGL5VTNOmzTBt2kzcvn0LEyeOhQOOJCMiojJi+FZBo0aN\nRZcuj2HXrl+wcmWM3OUQEdFdGL5VkEKhwMKFy+Dn54eZM6fh1KmTcpdERESFMHyrqMDAIHz66WLk\n5ORgzJgRyM7OlrskIiLKw/Ctwnr3fhzDh4/EqVMnMWvWdLnLISKiPAzfKm7GjPcQEdEQK1Ysw65d\nP8tdDhERgeFb5Wm1WixdutJy+NGtW7fkLomIqNpj+FYDzZu3wNtvz8CtWzcxadI4Hn5ERCQzhm81\nMXr0OHTu3BW//LI
Dq1atkLscIqJqjeFbTeQfflSjRg3MnDkVp0+fkrskIqJqi+FbjQQF1cKnny5G\ndnY2xowZycOPiIhkwvCtZvr06Yvnnx+BkyeP4733ZspdDhFRtcTwrYZmznwP9es3QEzMYuzevUvu\ncoiIqh2GbzXk7u6OZctWQq1WY8KEMbh9+7bcJRERVSsM32qqRYtIvPXWdNy8mYRXX32Zhx8RETkQ\nw7caGzduAjp27IIdO7Zj9eqVcpdDRFRtMHyrMYVCgUWLlsHX1xfvvDMFZ8+ekbskIqJqgeFbzdWq\nFYx58xYhOzsbo0ePQE5OjtwlERFVeQxfQt++/REVNRwnThzD+++/K3c5RERVHsOXAADvvvsB6tWr\nj6VLFyIu7le5yyEiqtIYvgSg4PAjlUqFCRPG4M6dO3KXRERUZTF8yaJly1Z4881pSEpKxKuvjufh\nR0REFYThS1bGj5+IDh064aeftuGLLz6XuxwioiqJ4UtWpMOPYuDj44Pp09/C6dOn5S6JiKjKYfhS\nEcHBtfHJJwuRlZWF/v374+LFC3KXRERUpTB8yab+/Qdg8uTXcf78eTz+eDf8/fc+uUsiIqoyGL5U\nrDffnIbly5cjLS0Ngwf3w/fffyt3SUREVQLDl0o0atQorF//HVxdNRg9egQ+/fRj7gVNRHSfGL50\nT126PIbY2J9Rp04IPvhgFiZNehm5ublyl0VE5LQYvlQqjRs3wfbtuxAZ2Qrr16/DM88MRlpaqtxl\nERE5JYYvlVpgYBC+//5H9O7dF3v37kHfvj1w5cplucsiInI6DF8qE3d3d3z++TqMGTMeZ8+eQZ8+\nj+HAgX/lLouIyKkwfKnMlEol3n33fXz44SdITk7GoEF9sXXrD3KXRUTkNBi+VG4jRozCunUboFSq\nMHJkFBYtWsA9oYmISoHhS/ele/de2LLlJ9SqFYx3352G//xnEgwGg9xlERFVagxfum/Nm7fATz/9\nimbNWmDt2s/x7LP/h/T0NLnLIiKqtBi+ZBe1agVjy5af0KNHL8TF/Yr+/Xvh6tUEucsiIqqUGL5k\nNx4eHlizZj1GjozGqVMn0bv3Yzhy5JDcZRERVToMX7IrlUqFDz6Yi9mzP8StWzcxYEAfbN++Te6y\niIgqFYYvVYjo6HFYvforAMDw4cMQE7OYe0ITEeW5r/DNzs5G9+7dsWnTJnvVQ1VInz598cMP2xEQ\nUBPTpr2FKVNeh9FolLssIiLZ3Vf4Ll26FN7e3vaqhaqgli1b4aeffkXjxk2xcuVyvPDCM8jMzJS7\nLCIiWZU7fC9cuIDz58+jS5cudiyHqqI6dUIQG7sDXbo8hl9+2YEnnuiNGzeuy10WEZFsBLGcP8RF\nR0dj2rRp2Lx5M2rXro0nn3xblRT0AAAgAElEQVSy2LZGowkqlbLcRVLVYDAYMH78eCxfvhy1a9dG\nbGwsIiMj5S6LiMjhVOW50+bNmxEZGYmQkJBStU9J0ZfnaUoUEOCJW7cy7P64ZM3e/Txr1seoVSsU\nM2dOxaOPdsBnn61G9+697Pb4zoqfZ8dgPzsG+1kSEOBZ7LpyhW9cXBwSEhIQFxeHxMREuLi4ICgo\nCO3bty93kVQ9CIKAl19+BaGhYXj55VF47rmn8f77H2PEiFFyl0ZE5DDlCt/58+dblhcuXIjatWsz\neKlM+vcfgODgYERFDcWbb76GS5cuYsaM2VAq+fMEEVV9PM6XZNO6dRts374LERENEROzGC+++Bx0\nOp3cZRERVbj7Dt8JEyaUuLMVUUnCwsKxbdsv6NixM376aRsGDnwcSUmJcpdFRFShOPIl2Xl7+2D9\n+u/wzDPP4ciRQ+jTpxtOnTopd1lERBWG4UuVgouLC+bPX4wpU6bj6tUE9OvXE7t375K7LCKiCsHw\npUpDEARMmvQfxMSsQm5uDoYNewpr166WuywiIrtj+FKlM2jQU/j2263w9vbGa6+9g
lmz3oHZbJa7\nLCIiu2H4UqXUtu0j+PHHXahXrz4WLvwUo0YNR1ZWltxlERHZBcOXKq0HHqiHH3/ciXbtHsXWrZvx\n5JN9cevWLbnLIiK6bwxfqtR8fWtg48bNeOqpp3HgwH706dMNZ8+ekbssIqL7wvClSs/V1RWLFy/H\n66+/hfj4y+jbtwd+//03ucsiIio3hi85BUEQ8Prrb2HRohjo9ToMGTIQX3/9pdxlERGVC8OXnMqQ\nIc/gm29+gIeHB155ZSw+/HAWyjkrJhGRbBi+5HTat++AH3/chbCwcMyb9zHGjh2J7OxsucsiIio1\nhi85pfr1G2D79l/Rpk1bbNr0Lf7v/wbgzp07cpdFRFQqDF9yWv7+/vjuu60YOPBJ/P33X3j88W64\nePG83GUREd0Tw5ecmkajwbJlqzBp0n9w6dJF9OnTjXtCE1Glx/Alp6dQKDBlynTMn78YGRkZePLJ\nfnjhhWE4ffqU3KUREdnE8KUqY9iwKGzZ8hPatGmL7dtj0aVLO0yYMAbx8VfkLo2IyArDl6qUhx56\nGLGxP2Pdug1o1KgJNmz4Cu3aPYgpU17HzZs35S6PiAgAw5eqIEEQ0LNnH/z66+9YuvQzBAfXxmef\nxeDhh1vigw/eRVpaqtwlElE1x/ClKkuhUGDw4CH4888DmDPnU3h6euLTT+eiTZsWWLhwPvR6vdwl\nElE1xfClKk+tVmP48JH4++/DmDp1JgBg1qzpaNs2EqtXr4TBYJC5QiKqbgTRAefmu3Urw+6PGdCm\nOUzmoqXrx72C7JHRAADPcaOg/vuvIm0MrR9CxvLVAADN2tXQzp9r8zmS/zoIuLhAee4svIc+abNN\nxryFMHTuCgDw6dUFitu3i7TJHvIM9P99GwDg/s7bcI39oUgbU2gY0r7fBgBw2b4NHlP/a/P5Urfu\ngDm4NoTUFPh262izjW7KdOQMHgIA8Hr2/6CysddvbtfuyJw7HwDgtnA+3FZ/VqSNqNVCdfoUbt3K\ngGr/P/AaPcLm86WvWgtjy1YAAN+2kRCMxiJtsqLHImv0ywAAj0kvw2XvniJtjM1bIn21dL5m16+/\nhPvHH9h8vuQ9+wAPDyguX4LP4P4222TOmYfcbj0BAD79ekJx47plndlsRkZ6OlZl6fG60Yjw8Lr4\nrmEjtDxxHBAEq8cx1wpGauzPAACXXT/D443JNp8v9butMIfXBTIzUaPzIzbb6F5/CzlDnwUAeA1/\nFqpjRyzrlAoBJrOI3I6dkTl/MQDALWYx3JYvLfI4okqFlL8PAwBURw7Ba0SUzedLj1kF40MPAwB8\nOz4MwcZIP2v4S8iaMAkA4PGfSXDZvbNIG2Ojxkj/8hsAgOt3G+H+/rs2ny9l116IPr5QXL8Gn/69\nbLbJnP0Rcvv0BQB4D+oLpY2d4XL6DYBu5nsAAO1H70GzcX2RNmZ/f6TuiAMAqPfshufkCTafL+3r\nTTA1iAByc1Gj3YOWfi5MP+k/yI4aDgDwjB4O9YH9RR7H0LYdMpasAABoVi6Hdsn/bD5f8oHjAADl\nyRPwjnraZpuMRTEwtHsUAODb9VEI6WlF2mQ/+zz0k98AALhPeR2uO7YXaWOqVx9pGzcDAFy2bobH\njKk2ny9l+68Qa9aEcPMmfPs8ZrNN5ozZyO0/EADgPWQglBeKHi+f06sPdO9/DADQzpsDzZdfFGkj\nenkjZfcfCAjwROqWn+A5frTN50tbuwGmJk0BADVaN7PZRs7vcnsJCPAsdp3Krs9E5AQUCgW8fXzw\nwpBncBoivvjic+y4fAmBajW8vX3g5uYmd4lEVMU578g3wLNCHpesVYd+vnLlMj7++AN8883XEEUR\nDz/8CKZOnYFHHmnvsBqqQz9XBuxnx2A/S0oa+fI3X6r2wsLCsWhRDPbs2Yc+ffrhn3/24YknemPo\n0CdxrNCmYSIie2H4EuVp1Kgx1qz5Ctu370KHD
p3w66870a1bR0RHD+c5o4nIrhi+RHdp3boNvvtu\nKzZu3IzIyFbYvHkTHn20DV577RVcv35N7vKIqApg+BLZIAgCunR5DDt2xGHlyrV44IF6WLt2Ndq2\njcQ777yN5GROX0hE5cfwJSqBIAjo338A9uzZhwULliAgoCaWLl2Ihx5qgblzP0RmJncqIXJqZjOg\n0wGZmQ59Wu7tTCViP1vLycnBmjUrMX/+XNy+fRv+/v6YOPE1vPDCSGg0mnI/LvvZMdjPjnHf/SyK\ngMEAITsLQlYWoNdDyM6GkKWHkJUFITsL0GdJfxe6HdlZEPRZBW2yCrXRF2pT+PbsbOkpFQqkffUt\nDI91t1MvlLy3M8OXSsR+ti0zMwMxMUuwZMlCZGSko3btOnj99bcwZMgzUKnKfvg8+9kx2M92IoqA\nTgdBp4Ogy4Sg00Ghy4SgywR0OngrTMhISraEoJCVBdwVggXhWNBG0OuB/DA1mexbsloN0U0L0c0N\n0GggarUQNRrLbaKPL3RvvwNznRC7PSfDl8qN/Vyy5OQ7+N//PsWqVcuRnZ2NBg0i8Oab09Cv3xMQ\n7jpbVknYz45RLftZFKWQKxSUQmZmwbLuruXM/GUdhMyMguXC6/Q6CHaIDlEQADc3KfzcCsIQbm4Q\nNW4QtXnrNG557Qq3KRScGqkd7g5UjRugzbsux3+K7xfDl8qN/Vw6169fwyeffISvvloLk8mEyMhW\nmDLlHXTu3LVUIcx+dgyn6Oe8sFSkp0FIS4OQnlZ8GBYXmlZ/Z0Iwm8tfjiBAdPeA6O4uXTw8Cy17\nWK9zl9Z5BvkhzaQoCNG84LQEZn6AuroWOaVrVcLwpXJjP5fNhQvn8NFH72Hz5k0AgA4dOuHtt99B\n69ZtSrwf+9kxHNLPZjOEjPS84Ey3CtGC5XTp70LLVuttnB+9tEStuyUMzR6eQKHl/Nvh7gGzR35o\neuSFqDtErUeRUIVWW+aA5OdZwvClcmM/l8+xY0fw/vvvYteuXwAAvXv3xVtvTUPjxk1stmc/O0ap\n+jk3F0JaGhQZxQVlqnVopqVBUWhZyMwo8yZZUauF2csbopcXRC9vmL3zl32k68JBefeIs/Cy1h1Q\nyH8QCz/PEoYvlRv7+f7s2/cnZs+egX/+2QdBEPDUU0/jjTemICws3Kod+9nOjEYIKSlQJN+BIvkO\nhDvStWeuDvrrN4sP0Yx0aeegMhAFAaKnF0Rvbyk4ve5a9vKG6C2FqNnLO29d/rJ0O9TqCuoIefDz\nLGH4Urmxn++fKIrYuXMH3nvvXZw8eRxqtRpRUcPx6qtvIDAwEAD7uURmM4TUFCiSky0hWjhQFcl3\nICTfgeJO3nVyMhRpqaV+eFGthujtDbMlQPOCMi9ArZel8LSMUr29IXp4VorRZmXCz7OE4Uvlxn62\nH7PZjM2bv8OHH87G5cuXoNVqER09Di+//Arq1w+pHv0sitIo0yo4k6G4Yx2iluWUZAgpKaXaYUhU\nqWCu4QfRz0+6riFdm/1qWJa9wmsjRVRLI868wIWbW5Xe6UcO/N6QMHyp3NjP9mcwGPDVV2sxd+6H\nSEpKhI+PDyZNmoQePfqhQYMIucsrPVGU9qS1MRq1CtSUZGn9HWm5NDsTiQoFxBo1pPDMD1K/QoFa\no4YUsr41LIErenrdM0T5eXYM9rOE4Uvlxn6uOHq9HitXLsfChfOQmiptJm3YsBH69n0C/foNQNOm\nzcp0rLDdmM0Qbt+GMukGFEmJUCQlQZFYaDnphnR96yaE3NzSPaSvr1WIWo9M85Z9a0D0ywtTb58K\n2ZTLz7NjsJ8lDF8qN/ZzxcvISMcff/yKr77agLi4XcjOO91deHhd9Os3AP36PYFWrVrffxAbjVDc\nvpUXoolQJCbeFah5t926WeLZhUQXF5gDg2AOCIDZz79oiNbwsx61+vjIcoIDW/h5dgz2s4ThS+XG\nfnaM/H7Oz
MzErl0/IzZ2C375ZQf0eh0AIDi4Nvr27Y9+/Qbg4YcfgVKpLLizwQDFrZt5o9OkvBC9\nAcXNJOuQvX2rxN9ORY0G5ppBMAcGwhxUC6bAQClk8y9BtWAODIToW8NpfyPl59kx2M8Shi+VG/vZ\nMWz1c1ZqKvbH/oCD27bg8l9/wFuvRy0AdTUaNK/hh1C1Gp46HRR3bpd4XKmo1cJcMxCmoFp5IRpk\nFbJSuAZKm3qdNFRLi59nx2A/S0oK38qxLYiousnKgjIhHsqEK1DExwOpt+B58Yr1ZuDkZIQCePLu\n+2ZnA9evIQPARYUCBv8AuNVvAP9mzSEE15HC1TJaDZIOhanioUrkbBi+RBUhOxvKawlQXLmSF7Lx\nUMRflpbj46G4dbPIXfInJDR7ecMcGAhj0+Yw1wwsGK3mBaohoCb+jr+CH3b9jG3btuLGjevArZvw\nOHYUPXr0RL/QAXisVWu4u7s79jUTUalxszOViP1cjNxcKK4m5IXpFSjyrqWQvQJlUqLNu4lqNcy1\n68AUGg5TaCjMIaEwhYbBq2kE7rh6wRwYJJ1Lt5TMZjMOHtyP2NgtiI3dgvj4ywAANzc3dO3aHX37\n9kevXn3g5eVtj1ft9Ph5dgz2s6RCfvOdM2cODhw4AKPRiNGjR6Nnz57FtmX4Oq9q288GAxTXrlqP\nWuPzlhPiobhx3ebvrKJSCXPtEJhCpVA1h4TCFBIKU2g4zKGhUrgW3lkqjz36WRRFHD9+DNu2/YDY\n2C04e/YMAECtVqNTpy7o128AevfuCz8/v/t6HmdWbT/PDsZ+ltg9fPft24eVK1dixYoVSElJwaBB\ngxAXF1dse4av86qy/Ww0QnHjuvWoNX85IR6K69ds7hksKhTSyDUktFCwhsEcGibdViu4XIfVVEQ/\nnz17BrGxUhAfP34UAKBUKtG+fQf07fsEHn+8H4KCatn1OSu7Kvt5rmTYzxK7h6/JZEJOTg60Wi1M\nJhPat2+PP//80/rwh0IYvs7LafvZZIIi8YYUpFcuW0asls3E167aPJZVFASYawVbNgebQkKlYM1f\nDq5dISfBr+h+vnz5ErZt24rY2B9w4MC/AABBEPDQQw+jX78B6Nu3P0JDwyrs+SsLp/08Oxn2s6RC\nDzXasGED9u/fj48//rjYNhXxJrRp4wmzjZHJuHG5GDnSkLeswd9/F/0PQevWJixfLp3IYO1aNebP\nd7H5HH/9pYOLC3DunAJDh7rZbDNvXjY6d5a+xHv10uL27aJ7lQ4ZYsB//yudCeidd1wRG1t0ZBQa\nasb330uzqWzfrsLUqa42n2/rVj2Cg0WkpgLdutneoWbKlBwMHiydwu/ZZ91w+nTRMwV17WrE3Lk5\nAICFC12wenXRQNFqRZw+rcStWxnYv1+B0aNt98GqVVlo2VJ6L9q2dYetswdGR+di9GjpfZk0yRV7\n9xbtg+bNTVi9Wnpfvv5ahY8/tt0He/bo4OEBXL4sYPBADWA0QDAYAEPetdGIJRiLx02xAIAO2Iur\nqFPwAAoloFLh/8L3YUbfP2EOCcP033rhu31hgEpptWdwrVpmxMZK78uuXUq88YYGtnz3nR7h4SIy\nM4HOnW2/L6+/noOhQ6XOGT5cg2PHCj6bCoUCZrMZHTsaMX++9L7ExKixfHnRz6ZKBfz9t3T875Ej\nCowYYft9iYnJwkMPSe9Lx45a6PXS6zIaTcjK0iMrS4/c3PkQxTkAAD+/b2A0doebmxvUhf6D0aiR\nGV9+mZX3OlV4/33b78uuXTr4+ADXrwvo39/279azZ+egTx+pDwYNckN8fNHPZr9+RsycKfXBRx+5\nYOPGop9Nf38RO3boAQB79igxebLt9+Xrr7PQoIEZublAu3buln4ubNKkXERFSZ/N6GgNDhwo+p3R\ntq0JS5ZIn82VK9VYssT2d8aBA9L7cvKkAlFRtt+XRYuy0a6d9J3RtasW6el
FvzOefdaAyZOl74wp\nU1yxY0fRfy/16pmxcaP0vmzdqsKMGbbfl+3b9ahZU8TNmwL69LH9vsyYkYP+/aX3ZcgQN1y4UPR9\n6dXLiPffl96XefNc8OWXRd8XLy8Ru3frERDgiS1b9Bg/3vb7snZtFpo0kd6H1q1t/3uR87vcXirs\nUKOdO3fi22+/xapVq0ps5+urhUple1R8PxQ2Tj/n6alBQID0hms0ts9Q5+qqQECAOq998WexCwjw\nhIsLcOdO8W18fLQICJCWVSrb7dzdXREQIP3D0Gptt1GrFZY3ytu7+Ofz8/NAQEDxzwUAXl5ulppc\nXGy3c3NzQUCA9EH18LDdJn9DRkCAJ3x9i38+X193y/MplYCt8zh4eNzP+yICJhOQKwVswKyp8Dh/\nGBnH9VCkfFv0gRQKCA3qA62GAuHhwDf1gMy8sywpVZZwVT05CO4fDJJqugkoDhV9qLK+L25uxbfx\n9Cx4X1xdi7ZTKBTQaEr3vuTXVJb3Jb+di4sCLi7e8Pb2xvPPT0OdOvWwadMm/PxzMkQxFWlpqVCr\n1dBqtXB3d4eLi9ryfF5exT+fv7/0OcnJKb6Nt3dBH6jVtttptQV94F7M9LQqVUEf+JRwJsoaNaQ+\nyM0taHP390bh7wxb7wsAaDSl/86Qntd+3xnFfaZcXBSlfF+kz6bZbL/vjNK9L9p7vi9ASf9e5Psu\nd4Ryj3z37t2LBQsW4LPPPoOPj0+JbbnZ2Xk5tJ+NRigvX4Ly7Bkoz52BKu9aee4cFLpMq6aiQgFz\naBiMEQ1himgkXTeIgKlBBEQn3LO3MnyeU1NTsGPHdmzbtgW7d+9CTo40yqlb9wHLaS4jIx+U53zT\ndlIZ+rk6qMh+FkXAaATyNnbBYBBgMEj/wTIagdxcoci6/PWF/zYYhEL3kf7TMXSoAV5e9qvV7pud\nMzIyMGzYMKxevbpUe04yfJ1XhfRzVhaU589Bde4MlGfPQHXurBSyFy8UOVG/6OICU70GBeEa0RDG\nBg1hqldf+u9wFVHZPs+ZmRnYuVM6zeXOnT9bTnNZu3Ydy2ku27RpW+x+HpVVZevnyshkks7jkpsr\nBVlODpCTI11b3ybdLt0GZGdLyzk5AtRqV6Sm5tgIwMLhd3cAFl1XEKrW7SrKypVZls3v9mD38N2w\nYQMWLlyIunXrWm776KOPEBwcbLM9w9d53U8/C2mpBeFaaDSriL9S5DAds4cnTBERMDXIC9eIhtLf\noeGV5qT8Fakyf56zsrIQF/crYmN/wI4d25GengYA8PX1RYsWkWjZshVatoxEixaRCA0Nq9Qj48rc\nz2YzkJUlhVh2trSclSXkBaFgFXZSEBaEoK1l27cVF6QFj2syyfP+qdUi1GppBKpSiXnXgIuLdHvB\npaBd/nL+xcVFzLtP0XVqtfVjFrST1nl5AW3amGwdCVhuPLczlds9+1kUobiZJIXr2TPSaDY/bG2c\naMLsH5A3ipXC1digIUwNG8EcVKtanwLRWT7Pubm5+P33PYiN3YK9e/fgypXLVut9fHzQvHkkWraM\ntARyeHjdShPIZelnUcwf0UmBqNcXBGN2tmAVjkWvC9rq9db3KQhX679zciq+jwRBhKurFDouLiI0\nGmnZ1VXMuw3QaETL+vy2rq62lotfX7OmFnq9/q4ALByu1gGoUlXNf/4MXyo3Sz+bzVDEX8nbVHy2\n0G+yZ6FISy1yP1NIKEwNCsLVmBe2om8NGV5F5eesn+fU1BQcPXoER44cxtGj0uXSpYtWbby9fdCi\nRUs0b97SEsrh4Q/Y3GGyLERRGqnpdAJ0uruvbd8mii5ISTFYQq9wgNoKUbPZ/ong5iaFnkYjXbu5\niXBzk/7Ovy58e36wFQ486/C71/qCx3BUyDnr59neGL5UJkJKMlRHDkN15BA8Lp6F4dgJqM6fhZA3\nz2w+UaWCqe4DeTs8ReSNZhvCWK+BtDs
klVpV+jynpaXi2LGjeYF8CEePHsGFCxcAuAPwAOABrTYI\nDzzQEqGhTREcHIGaNetBqw1EVpaiSGjq9cWHqz02kapURcPPVgjmL+cHp5tb6f8uuL1qjvDuVpU+\nz/eDsxpRsYSMdKiOHoHq8CGojhyE+tBBKO/alKjSamGMaGS9w1NEQ5jqPlAhJ5ygysFgANLTBaSn\nAxkZAtLSBMvf6ekCMjKKjjCloNRCpwuGTtfHchtgnTh6PXD8uHQpLa1WhLu7CHd3oEYNs2XZ+rrk\n22rXdkdWViY0GunxNJpqsUsBVUL82FUnej1Ux49BfeQgVIcOQnXkEJTnz1nt/GSuUQO5XbvB0OpB\nGFs+CO9Oj+C2WwkHk1KlZDYDmZn54WkdmmlpUnCmp8OynB+sGRkFt+WflKOslEoRHh5S2Pn6iqhT\n5+5QlJbV6hxkZNzAnTuXkZh4Htevn8aNG+cgiukAMgFkws1NRNOmddGqVSO0bNkSLVu2Qv36Dcq9\nl3VAAHDrVoVv7CO6J4ZvVZWbC9XJ49KI9vBBqA8fgvLMKatTKpo9vWB4tCOMkQ/CENkKxsgHYQ4J\ntd4uFuAJcPORQ4mi9Ptj4dC0DkncFZgC0tJQaFkKUVEsW3hKe3yK8PQEgoLM8PIS8y4otFxwm6en\nCA8PEVqtdai6uJRl02qtvEs7AIBOp8Px48dw9Oghy+/IBw/GYf/+Xy330Gq1aNq0uWWHrpYtW6FB\ngwioOIQlJ8LffKsCoxHKM6ehPnIob0R7EKqTJ6yOmRXd3GBs3jJvRCsFremBevcc0bKf7092NpCc\nLFhd7tyRrlNSCv7OzFQhOdlsGZ0aDGULTkGQQlMKTxHe3sWHZuG/vb0L7uPmVjl/j9Tr9Thx4hiO\nHj2MI0eky9mzp2Eq9B9JNzc3NGnSLG+HrlZo0SISDRs2KhLI/Dw7BvtZwh2uqhKzGcoL56E6fNAy\nolUdPwohK8vSRHRxgbFps7wRrRS2poiGlWa2HWeVk1NykOYvF/67tJtu3dwALy9zMSPNoiHq7Y1C\nISsWeyrKqiorKwsnTx63jI6PHDmMM2dOwVjoxOIajQZNmzbL28taCuSOHR9Gamp2CY9M9sDvDQnD\n11mJIhRXLhca0R6C6shhKDILXreoVMLUqIlls7ExshWMjZtK2/7soKr2c04OLCPPwkFaeDR6d5Dq\ndKULUq1WRI0a0sXXV4SfX/F/+/lJt4WEVM1+dqTs7GycOnXCKpBPnz4Jg8FgaaNUKlG7dh3UqRNi\nuYSEhFqua9euA1dX2xMUUOlV1e+NsuLezs5AFKG4cd0SsurD0rUiJaWgiSDA1CACuS1bFWw+btZC\nGjZVczodkJQkIClJgdu3bQdp4dFpZmbpglSjkQLygQfMRYLz7kt+kPLtkIdGo0GrVq3RqlVry205\nOTk4ffqkZXP1hQtncOnSZfz11x8obtwRGBiUF8YhCAkJsyzXqSOFtIeHh6NeElVhHPnKRLh1C+rD\nB6x2iFLcumnVxhReN29E21oa0bZoCdGj+P9JVQS5+zkzsyBUExMFJCUJSExU5N0mWNZlZNw7TF1d\nC8KzpCAt3EZrewY2u5O7n6uL/H7Ozc3FtWtXcfVqAq5eTUBCQrxlOT4+HtevX7XahF2Yr6/vXaEs\nBbMU1qHw8fGtNGf0kgs/zxKOfOWWlQX1/n+gOrgf6vxDfK5dtWpiql0HOY/3LxjRtoyssmeDEkUp\nVPNDND9Uk5IKQjV/3b029fr7mxESYkZgoIigIBGBgWYEBBQfpNX8O5HyuLi4oG7dB1C37gM215tM\nJiQlJSIhIQFXr8YjISHesnz1agLOnTuDo0cP27yvu7tHoVCWRs/5f4eEhCIgoOZ9n92LnB/DtyIY\njVAdPgiXvXug3rsH6n//hpA3PRsgnd84p0cvy2+0hpYPQqxZU8aC7UMUgfR0WI1MExMVuHlTsBq1\n3rx
Z8o5IgiCFZt26+aEqXdesWRCwQUEiAgJEe/20TWRFqVQiOLg2goNro23bR4qsF0URd+7cQULC\nlbyRc0Ewx8dL16dPn7L52K6urggOrm0VyvnBHBISilq1gnnYVDXAd9geRBHKUyfhsjdOCts//7Da\nKcrYtDlyO3aG4eFHYGz1IMzBtZ1qCCaKQGoqrDb9Wo9SC/7Ozi7+dSkUIvz9RdSrZ7aEaGCgaBWw\ngYFSqPLEWVSZCYIAf39/+Pv7W/3GXFh6elpeKCcgIeGKZVkaSSfgt99227yfUqlErVrBVjuF1ahR\nA76+NQpd+6FGjRrw8vLmKNpJ8TffclJcvpQ3so2Dy++/QXH7tmWdse4DMHTsgtxOnWFo3xGiv79s\ndZaG0Qhcvy4gPl6BhAQBV64oLMtJSSrcuCGWOOOKQiGNSvM3/dasab0ZWLqWgpf/obdN7s9zdVGZ\n+lmv1+Patat3/d58xbKcmHgDZrO5xMdQKBTw9fWFr68Uyn5+fpbl/KDOX65RI3+dL1wqeJNRZepn\nOfE3XzsQkpLg8ru0Gdnl99+gjL9iWWcKDEL2U08jt1MXGDp0grlOiIyVFmU2SzstXbkiBWp+sMbH\nSyF77Zpg8wT1CoWIWolNSOEAAAkTSURBVLWAJk3Md41SrUet/v6iXefAJKoOtFotGjSIQIMGETbX\nGwwGXL9+DdevX0NycjJSUpILXd+x+jslJRmXLl20OvFISTw8PG2MpouGduEwd3d3r/Y7ktkTw7cY\nQloq1H/+IY1s9+6B6sxpyzqztw9yHu8vbUru1AWm+g1k3YwsisDt24JVoMbHFyxfvSogN9d2fUFB\nZjz4oBmhoWaEhZkREiIiNFT6OzhYRHCwJ27d0jv4FRGRWq1GWFg4wsLCS9XebDYjPT3NKpALL9+5\nU/T2s2dPI6vQCXpK4uLiYmMUbTu869ULQW6uAHd3d2i17uU+F3dVxvDNl5UF9T/7LJuSVUcOQ8jb\n5CO6uSG3y2PI7dgFhk6dpWNrHfxhSk0FEhIUVqPXwiPY4nZg8vc3o2lTsyVQ88M1LMyM2rWlWV2I\nyPkpFAr4+PjCx8cXQL1S30+v1xcJ6sIj7Ltvv379Ok6dOlmm2jQaDdzd3eHu7gGtVpsXyh5511q4\nuxddzg/u/PvZauvMv3dX3/A1GqE6dMB6j+S8cyGLKhWMDz1sGdkaHnxImp26AmVmSuEaHy9YQjZ/\nOT5egfR02+Hq5SWdACI/WMPCCpZDQszg+QCIqCRarRZarRa1a9cp9X2MRiNSUlKKDe2srAwkJ6dC\np9PlXTKh1+uh0+mQlJQInU6H3ELnnr+/2u8O6sIhbyvIrf8TkL/s6+sLb2+f+66ptKpP+JrN1nsk\n//WnZY9kURBgbNYChg6dYOjUGblt28PeqZWTgyKbhfODNT5ewJ07tv8Hp9VKI9VHHpHCVBrBFmwa\n9va2a5lERPekUqkQEBCAgIAAm+tLs8OVwWCAXq+zBHTBcmbe33rLsvV66zDPb5OamorMzIxS/+59\nN4VCga+++haPPda9XPcvq6obvqJYsEfy73uK7pFcrz5yBg+R9kh+tCPEGn52eVq9Hjh/XoEzZxQ4\nezb/WonLlwWYzUVHry4uIkJCRDRvbiwSrKGh0vGu3MeBiKoatVoNb28fu442RVFEbm6uzXC2DvOi\n6wEgIqKh3Wq5lyoVvoqkRGlUm79HckK8ZZ2pVjCyhzyD3A6dYOjYGeYybGKxJTMTOHs2P2CVlqBN\nSBCKzKNao4YZbdqYUK+eFKjSCFbaRFyzplitZqMhIqoogiDA1dUVrq6uqGGnAVVFcerwFdJSof7j\nd2lT8u+/We+R7OuLnH4DpLDt1AWmevXLtUdyaipw5owS584VjGbPnlXg2rWiiVmzphkdOpgQEWFG\nRIQZDRtK1/7+FX4oNRERORHnC19RhNvC+cCOWPgdOFCwR7JWi9zHu
kt7JHfsJO2RXIYh5e3bQqHN\nxAWbjG/eLPoYwcFmdOlitISrdDHB19dur5KIiKow5wtfnQ7un3wIGI0wPPwIDB07S5cHH7rnHLai\nCNy8Kdz1e6x0sbXDU2ioGd27G/NGsdKItkEDM7y8KurFERFRdeB84evhgTv7j8M/LBBpetunXhNF\n6XSJ1qNY6XfZtDTrTc+CICI8XESbNgarzcX165vh7u6IF0RERNWN84UvADEgAHB3hzkzAwkJgtVe\nxfnLd09Fp1RKx8N26GC22lxcr56Zk58TEZFDOV34iiIwfbor/v0XOHXKA1lZ1iGrVouoX99cZKen\nBx4wc/o5IiKqFJwufPV6YMMGNbKzYQnZ/IBt2NCE8HDOnENERJWb08WUuztw4kQmAgM9kZzME/4T\nEZHzccrTO6jVDp/XgIiIyG6cMnyJiIicGcOXiIjIwRi+REREDsbwJSIicjCGLxERkYMxfImIiByM\n4UtERORgDF8iIiIHY/gSERE5GMOXiIjIwRi+REREDiaIoijKXQQREVF1wpEvERGRgzF8iYiIHIzh\nS0RE5GAMXyIiIgdj+BIRETkYw5eIiMjBnC5833//fTz99NMYOnQojh49Knc5VdqcOXPw9NNPY/Dg\nwfj555/lLqdKy87ORvfu3bFp0ya5S6mytmzZgieeeAJPPvkk4uLi5C6nStLpdBg/fjyioqIwdOhQ\n7N27V+6SKi2V3AWUxT///IMrV65gw4YNuHDhAqZMmYINGzbIXVaVtG/fPpw7dw4bNmxASkoKBg0a\nhJ49e8pdVpW1dOlSeHv/f3v398r6H8Bx/LkzubBxzDJaIblRSigXWHJBLlz7kRa3cqVc0FKUq7lS\nKAp/gLZwI0pZuZgr5UJRXGExy8evxgU6d6fOt9x8a3vbp9fjbrt61i5ee38+n7bfpjNsy7IslpaW\niEajpNNpFhYW6OjoMJ1lO5ubm1RXVzM+Ps7d3R3Dw8Ps7u6azvqRcmp84/E4nZ2dANTU1PD09MTr\n6ytut9twmf00NzdTX18PQFFREW9vb3x+fuJ0Og2X2c/l5SUXFxcagwyKx+O0tLTgdrtxu93Mzs6a\nTrIlj8fD+fk5AM/Pz3g8HsNFP1dOXXZOpVL/fJglJSXc398bLLIvp9NJQUEBAJFIhPb2dg1vhoTD\nYSYnJ01n2Nr19TXv7++MjIwwODhIPB43nWRLPT09JBIJurq6CAaDTExMmE76sXLq5Ptf+mXMzNvf\n3ycSibC+vm46xZa2trZoaGigoqLCdIrtPT4+sri4SCKRYGhoiIODAxwOh+ksW9ne3sbv97O2tsbZ\n2RmhUEjPMXwjp8bX5/ORSqX+vk4mk5SWlhossrfDw0OWl5dZXV2lsLDQdI4txWIxrq6uiMVi3N7e\nkp+fT3l5Oa2trabTbMXr9dLY2EheXh6VlZW4XC4eHh7wer2m02zl+PiYQCAAQG1tLclkUrervpFT\nl53b2trY29sD4PT0FJ/Pp/u9GfLy8sLc3BwrKysUFxebzrGt+fl5otEoGxsb9Pb2Mjo6quHNgEAg\nwNHREV9fX1iWRTqd1v3IDKiqquLk5ASAm5sbXC6XhvcbOXXybWpqoq6ujoGBARwOB9PT06aTbGtn\nZwfLshgbG/v7Xjgcxu/3G6wS+X/Kysro7u6mr68PgKmpKX79yqmzR07o7+8nFAoRDAb5+PhgZmbG\ndNKPpb8UFBERyTJ99RMREckyja+IiEiWaXxFRESyTOMrIiKSZRpfERGRLNP4ioiIZJnGV0REJMs0\nviIiIln2BzQKNGAGnBgwAAAAAElFTkSuQmCC\n", - "text/plain": [ - "\u003cmatplotlib.figure.Figure at 0x7f7a18df6b50\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": 
"display_data" - } - ], "source": [ - "# Train our variables.\n", - "\n", - "# numpy is used for its asscalar() function.\n", - "import numpy as np\n", - "\n", - "num_training_steps = 10\n", - "\n", - "def train_model(inputs, labels, wb, optimizer, num_training_steps):\n", - " loss_at_step = []\n", - " w_at_step = []\n", - " b_at_step = []\n", - " for step_num in range(num_training_steps):\n", - " loss_at_step.append(run_step(inputs, labels))\n", - " w, b = wb.variables\n", - " w_at_step.append(np.asscalar(w.numpy()))\n", - " b_at_step.append(np.asscalar(b.numpy()))\n", - "\n", - " print(w_at_step)\n", - " t = range(0, num_training_steps)\n", - " plt.plot(t, loss_at_step, 'k',\n", - " t, w_at_step, 'r',\n", - " t, [true_w] * num_training_steps, 'r--',\n", - " t, b_at_step, 'b',\n", - " t, [true_b] * num_training_steps, 'b--')\n", - " plt.legend(['loss', 'w estimate', 'w true', 'b estimate', 'b true'])\n", - " plt.show()\n", + "## Next Steps\n", "\n", - "train_model(inputs, labels, wb, optimizer, num_training_steps)" + "In this tutorial we covered gradient computation in TensorFlow. With that we have enough of the primitives required to build and train neural networks, which we will cover in the [next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/3_neural_networks.ipynb)." 
] } ], @@ -572,7 +312,7 @@ "colab": { "collapsed_sections": [], "default_view": {}, - "name": "Eager Execution Tutorial: Working with Gradients", + "name": "Automatic Differentiation", "provenance": [], "version": "0.3.2", "views": {} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..84f1d031d40604ae029e8a8347474950ee01b38a --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb @@ -0,0 +1,485 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "k2o3TTG4TFpt" + }, + "source": [ + "# Training Models\n", + "\n", + "In the previous tutorial we covered the TensorFlow APIs for automatic differentiation, a basic building block for machine learning.\n", + "In this tutorial we will use the TensorFlow primitives introduced in the prior tutorials to do some simple machine learning.\n", + "\n", + "TensorFlow also includes a higher-level neural networks API (`tf.keras`) which provides useful abstractions to reduce boilerplate. We strongly recommend those higher level APIs for people working with neural networks. However, in this short tutorial we cover neural network training from first principles to establish a strong foundation." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "3LXMVuV0VhDr" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "PJ64L90aVir3" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "tfe = tf.contrib.eager # Shorthand for some symbols" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "eMAWbDJFVmMk" + }, + "source": [ + "## Variables\n", + "\n", + "Tensors in TensorFlow are immutable stateless objects. Machine learning models, however, need to have changing state: as your model trains, the same code to compute predictions should behave differently over time (hopefully with a lower loss!). To represent this state which needs to change over the course of your computation, you can choose to rely on the fact that Python is a stateful programming language:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "VkJwtLS_Jbn8" + }, + "outputs": [], + "source": [ + "# Using python state\n", + "x = tf.zeros([10, 10])\n", + "x += 2 # This is equivalent to x = x + 2, which does not mutate the original\n", + " # value of x\n", + "print(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "wfneTXy7JcUz" + }, + "source": [ + "TensorFlow, however, has stateful operations built in, and these are often more pleasant to use than low-level Python representations of your state. 
To represent weights in a model, for example, it's often convenient and efficient to use TensorFlow variables.\n", + "\n", + "A Variable is an object which stores a value and, when used in a TensorFlow computation, will implicitly read from this stored value. There are operations (`tf.assign_sub`, `tf.scatter_update`, etc) which manipulate the value stored in a TensorFlow variable." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "itxmrMil6DQi" + }, + "outputs": [], + "source": [ + "v = tfe.Variable(1.0)\n", + "assert v.numpy() == 1.0\n", + "\n", + "# Re-assign the value\n", + "v.assign(3.0)\n", + "assert v.numpy() == 3.0\n", + "\n", + "# Use `v` in a TensorFlow operation like tf.square() and reassign\n", + "v.assign(tf.square(v))\n", + "assert v.numpy() == 9.0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-paSaeq1JzwC" + }, + "source": [ + "Computations using Variables are automatically traced when computing gradients. For Variables representing embeddings TensorFlow will do sparse updates by default, which are more computation and memory efficient.\n", + "\n", + "Using Variables is also a way to quickly let a reader of your code know that this piece of state is mutable." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BMiFcDzE7Qu3" + }, + "source": [ + "## Example: Fitting a linear model\n", + "\n", + "Let's now put the few concepts we have so far ---`Tensor`, `GradientTape`, `Variable` --- to build and train a simple model. This typically involves a few steps:\n", + "\n", + "1. Define the model.\n", + "2. Define a loss function.\n", + "3. Obtain training data.\n", + "4. 
Run through the training data and use an \"optimizer\" to adjust the variables to fit the data.\n", + "\n", + "In this tutorial, we'll walk through a trivial example of a simple linear model: `f(x) = x * W + b`, which has two variables: `W` and `b`. Furthermore, we'll synthesize data such that a well-trained model would have `W = 3.0` and `b = 2.0`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "Y0ysUFGY924U" + }, + "outputs": [], + "source": [ + "def loss(predicted_y, desired_y):\n", + " return tf.reduce_mean(tf.square(predicted_y - desired_y))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "qutT_fkl_CBc" + }, + "source": [ + "### Obtain training data\n", + "\n", + "Let's synthesize the training data with some noise." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "gxPTb-kt_N5m" + }, + "outputs": [], + "source": [ + "TRUE_W = 3.0\n", + "TRUE_b = 2.0\n", + "NUM_EXAMPLES = 1000\n", + "\n", + "inputs = tf.random_normal(shape=[NUM_EXAMPLES])\n", + "noise = tf.random_normal(shape=[NUM_EXAMPLES])\n", + "outputs = inputs * TRUE_W + TRUE_b + noise" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-50nq-wPBsAW" + }, + "source": [ + "Before we train the model let's visualize where the model stands right now. We'll plot the model's predictions in red and the training data in blue." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 293 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 1210, + "status": "ok", + "timestamp": 1527005898290, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "_eb83LtrB4nt", + "outputId": "3873f508-72fb-41e7-a7f5-3f513deefe38" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAEDCAYAAAA2k7/eAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXlgU1X2xz/pAhRautCWUsCwWVlcUHHGBUFQcSg7uM8P\nFLUICo4VpygObihI3UdmUHBB0IGZQbEgFNGCqKgMolV2pKylCy1pukDp+n5/3LxmaUsDTUjSns8/\nbZKXd09C+b7zvvfccw2apmkIgiAITR4/TwcgCIIgnB9E8AVBEJoJIviCIAjNBBF8QRCEZoIIviAI\nQjNBBF8QBKGZENDYE+Tk5JCUlER+fj7+/v7cdtttTJgwgcLCQhITEzl27BidOnXijTfeICQkxBUx\nC4IgCOeAobF1+Hl5eeTn59OrVy9OnjzJ2LFj+ec//8mnn35KWFgYCQkJLFy4kKKiIh5//HFXxS0I\ngiCcJY22dKKioujVqxcAbdq0oXv37uTm5pKWlsaYMWMAGDNmDF999VVjhxIEQRAagUs9/MzMTPbs\n2cNll13GiRMniIyMBNRFoaCgwJVDCYIgCGeJywT/5MmTPPLII8ycOZM2bdpgMBhcdWpBEATBBbhE\n8CsrK3nkkUcYNWoUN910EwDt2rUjPz8fUD5/REREg+eRtj6CIAjuo9FVOgAzZ86kR48e3HPPPTXP\nDR48mE8//ZRJkyaxcuVKbrzxxgbPYzAYyMsrdkVIbiUqKkTidCESp2vxhTh9IUbwrTidodGCv23b\nNlavXk1cXByjR4/GYDCQmJhIQkICjz76KJ988gmxsbG8+eabjR1KEARBaASNFvwrr7yS3bt31/na\n4sWLG3t6QRAEwUXISltBEIRmggi+IAhCM0EEXxAEoZkggi8IgtBMEMEXBEFoJojgC4IgNBNE8AVB\nEJoJIviCIAjNBBF8QRCEZoIIviAIQjNBBF8QBKGZIIIvCILQTBDBFwRBaCaI4AuCIDQTRPAFQRCa\nCSL4giAIzQQRfEEQhLOk0GTi84R7+XbIDXyecA+FBSZPh+QULtnTVhAEoTnx7YzHuDflUwyAlv4z\nizEwfNFiT4fVIJLhC4IgnCWhhw9hsPxusDz2BVwi+DNnzuTaa69lxIgRNc/Nnz+fAQMGMGbMGMaM\nGcM333zjiqEEQRA8TqHRiGb5XQMKjV08GI3zuMTSGTt2LOPHjycpKcnu+YkTJzJx4kRXDCEIguA1\nXJ/8OosxEHr4EIXGLlyf/JqnQ3IKlwh+v379OHbsWK3nNU2r42hBEATfJjQ8wic8e0fc6uF//PHH\njBo1iqeeeori4mJ3DiUIgiA0gNsE/+677+arr74iJSWFyMhI5s6d666hBEEQXMLRjAwW9u3FWmN7\nFvbtxeGMDE+H5FLcVpYZERFR8/vtt9/O5MmTnXpfVFSIu0JyKRKna5E4XYsvxOmNMb53xQhmZh1T\n5Zalx5h3ww08
cfSop8NyGS4TfEe/Pi8vj6ioKAC+/PJL4uLinDpPXp73Wz9RUSESpwuROF2LL8Tp\nTTEWmkx8O+MxQg8fIjory67cMtZk8po4z4SzF0+XCP706dPZsmULZrOZG264gWnTprFlyxZ2796N\nn58fHTt25Pnnn3fFUIIgCC7FdhHVXFSZpcHyM8vGqWgKuETwX3311VrPjRs3zhWnFgRBcCu2i6ju\nBp4JDKR7QACZ4RH839dfezAy1yMrbQVBaNbYLqK6AOgaP4L4w7lMSt+NsXt3T4bmcqSXjiAIzRpf\nXUR1LojgC4LQrPHVRVTnglg6giA0WXy1jbG7kAxfEIQmi6+2MXYXIviCIDQZbGvqC41Ggg9k+GQb\nY3chgi8IQpPBMaOfE9vRrq7eV9oYuwsRfEEQfB49s/dbn2qX0V8QEcHiq/7odAWOyWRmxoyNHD7c\nFqOxkPffHwX4uzv884YIviAIPk2hycS/B1/HJVnH2In9StnK7heelWc/Y8ZGUlLGAwbS0zWmTFnO\n/PnD3RK3JxDBFwTBJzmakUHquOFE5mTTpbqaAcD1wDygQ1AQ1UOGnnVN/eHDbcHmHuHgwWDXBu1h\nRPAFQfBJUscNt3a2BJYDdwG9gRNDhp5TNY7RWEh6uvUeoWvXEhdG7HlE8AVB8BkKTSY2PDqVgB+/\nI8ZstvPrg1HCvz22I3ec42rZ5OTBwFKLh1/EggUjqapyTezegAi+IAhejz4pq23aQBuzmWHAAuz9\n+t8CA8mPH8Edya8RGn5uXS7tu7w3vS1aRfAFQfBq9ElZR/vmbuBZwGgwkN0hlqEr19C5a7dGjdXU\nJ22ltYIgCF7NtzMe4xKL2IPVvrkAuAgwjBzDpPTdjRZ7aPqTtiL4giB4NaGHD1GC1WDR7ZvktqFs\nbH85r2eMJiHhUwoKzI0ey2gstBtJJm0FQRDciF5u2anARGZ4BC169eIBlI3TBsuk7MbNPJ60Sdkv\nuQa279CApSxaNOasx7NdbNWhw0mGDn2P7OxImbQVBEFwF/rEbNba1cysqKjZSPyF6mo+GzWW0MOH\nOGHsUjMp62i/HD7cttZK2eTkwYSHh51xXEffftSopaxffyMAERHes/euKxDBFwTBK9D74HwO9u0R\nCs3E11FT71gzbzQW1RJvZ7L+ui4cTRWXCP7MmTP5+uuvadeuHatXrwagsLCQxMREjh07RqdOnXjj\njTcICXFuZ3VBEJo+u7Zt48vRQ+ladpqDBgNRrVtjAIqxL7fMrKfE0rFmPjl5EHfcsY2zFe+6LhxN\nFZcI/tixYxk/fjxJSUk1zy1cuJBrrrmGhIQEFi5cyDvvvMPjjz/uiuEEQWgCfDkmntllp5XMahpP\nnzyJBsQDy4BC/MhoFcawxcvqfH94eFit7P1cxLuuC0dTxSWC369fP44dO2b3XFpaGh999BEAY8aM\nYfz48SL4giCwa9s20sbG0+10qZ110x14JSwMf8LZZO7HKt6G0+Hs/8dSFi3q69S5z0W867pwNMS5\nzBV4A27z8E0mE5GRkQBERUVRUFDgrqEEQfAQZyN8+qTs6VUruUjTOIS9dZMNxAwczN8Pjyc9fXTN\n+87GUz8X8T4XzmWuwBvwuknbqCjf8PklTtcicbqW8xXn1Kmf2wlfy5bL+fe/76p1nPnECVbc1J/e\nmZmUAEOBpajOltHAAYOBsBtvZMz7i1g3ZZ2dLRMXV4qfXxUPPZTKwYPBdO1azIIF8UREnJ+Muq7v\nMisrHNu5gqyscJ/423Cb4Ldr1478/HwiIyPJy8sjIsK53ha+UAIVFeUbpVoSp2uROGuzb18QtsK3\nb18QeXnFtTL/+PIVzMjMtGuN0BUYDsxqFcRfjuQCUFEFs2dfT1mZ1ZaZPXsQ99+/qubCsnWrRlnZ\nUubNG+R2W6W+7zI21oTt/UlsbIFH/zacvdi4TPA1+65DDB48mE8//ZRJkyaxcu
VKbrzxRlcNJQiC\nl1DfJKlueRgwcUH6FPBbb+fXtwF+A7a0CuLmVevszlmXLVNX6aQnbRVfneh1ieBPnz6dLVu2YDab\nueGGG5g2bRqTJk3iL3/5C5988gmxsbG8+eabrhhKEAQvoi7hKzSZCN34D17hUfIoYS4VLKu29+t3\nderEHWnfOd3Vsq4Liyfr58/XXIGrcYngv/rqq3U+v3jxYlecXhAEL8VW+ApNJj57YAKl337NDSgp\n7mD5GY+yccotO1FNfn8RuXknSUhY6ZQlU9eFJSlpQ7Opn3cVXjdpKwiCb/LtjMeI/fZr7sKayb9k\n+RkG3AkstuxEFRYRwr33fe60JaNfWPS5gTvu2Far742v2CqeRARfEASnqasMs9BUwOJxCVyanU4+\nUIgSeAOqffFLQPuwMAwDB3N98muYTGamTv2c9evB1pLJyPBvMOOvr++NyWQmKcn36uLPNyL4giA4\nja3g/pqeT+DqP3BJ9UHmY83ql6E2J9GAn4Hc9pezLCqRblRzHX4251iGrbNvMh1mx44nOVPGX59v\n76t18ecbEXxBEJxGF1wDJ5jM5fyjOrNWs7NC4G3gd/zZenUS3/74ol0LY6toK2c/KKiCIUPgwIE4\nsrLOPAnrOHl7/PguDhzowaZNucDnqE488U26AVpjEMEXBKEW9a2g7djhGMb0eG5mHa3R6mx2to6r\nWcUPwCrC9uRbXjEDqaxfD+HhO4GBNe8oKytl69Z8evUKtjuTPglr36++nPbtnyY39yrgJFlZUxg7\ndgFm85PY3mMYjZVn/BzNFRF8QRBqoSySEcA60tPD2bp1CY/cW0nv1GfpDeQBuWDX7Kwc2A2s4l+W\nV04C+ZbfU4E7KS01UFqqERT0DKWlLYGZVFcbyMrSqK5+gVGjate2O9o1YWGvACNrYi0o6ITtPUZY\n2GmSk2+u873N3eoRwReEZsDZZrrKElkH3Ik/WxmSNYaCOdX0B0qAB4BVwNNAJ9QFIB9Y1vZeKNoO\n/Aj8iWuuWU6LFktZvx5KS62iXFraDyXS1ucKC411irGjbw/tsL0TCA8/Smmp9fHAgQE1n6059bp3\nBhF8QWgG1JXpnqk1gdFYyK/pp4nnEv7ATiKBUGCA5edyIALoDOwliNfIJCzsM7744g/MmfOz5Zyr\nSU4eTnh4GAkJn5KSYmv8nLT8tBXuzDpjd/Ttr7mmmhYtrHcCM2eOYs6cule9Nqde984ggi8ITRDH\njP7AgTY01JrgxImX+emnYsrKuhKifcxf2EAw0BdqGp6lAnehWiMUA7tpxRtsB8Ixm1vx7LPf0aJF\na8s41nYrtgunjh/fRVbWFEs8y/DzKyYm5gQrV1ptGltqL7q6pdbdyaJFRiff27xr9UXwBaEJ4ijm\nsbFzcJwQdbQ7Nm8uAm0ioxjFzeykEHgCa06+HNCnVf8H/MKlrGUssMvyTDybNy+kqOhBHD1z2xW5\nBQVXMmvWOvbtC8JorCQ5Of6M9lJj2hj4agsEdyGCLwhNEEcxj4jowlVXnbk1gb92kkR6MM/yzCrs\nnfM2wK/Atxh4mXuB14A1qJ6X6hynToXSkGceHh7Gv/99l090Hm1qiOALQhNEedcFqInXlvz++06O\nHGmFn18nOnSoAuztjuh2+7k07Q36Y5XrEuzLLb8DXmYDMAhQ1TLV1aUUFS1BOfoltG5toqiocZ65\nlFK6DxF8QfBRMjIOM27cKgoKOhEefpSVK0fRtavyspOTB7N16wKyslR9elnZGMrKlgHxpKauZfPm\nLwgOziW8bTsuP/4SxvTddMZe5IcCs4COwC78mc8e4D+Wo4rp1CmW7t0rSUmZgC7w/fq9xZ49cy0x\nZTJzZm1fXm+toCyd2oIupZTuQwRfEFzI+cxOx41bVSPopaUaI0c+w9VX9yArK5zYWBMREUa7lasQ\nhFoDO4PiIhN9i/7ENVk/0Qb4G/AKcDvKq2
8DbAPKgHXczKqaupyLgRGAxv79s3jvvTuxnRQtLw+0\ni2nOnKW1JlQbEnQppXQfIviC4EKczU6dvTDUd5zJZCYnR8O2nUBeXsuasdUuTHOxN2X2AH3wYz/j\nuYgYNHoDpZYjwlBVOCGWM5qBv/Moyqu3LacEMHD6tCrBtP18Q4akoZorpALBbNqUQ0GB2e6z2Qt6\nIZs25TJkSFrN55NSSvchgi8ILsTZ7NTZC4PjcWVl7wGQlpZLdfVMbNsJaFqE3dgFBbG0ajWL06fb\no0Q4GH/SmcooWgNXo8wZ/Qy3AWuBTGB/y1DSus7B//dDVFUtAyqBY8Bky/mV+Dt+PiXWa8HSJNls\nHk5S0lK71saHDlUCHwPDgLWYzY+Tnm79HqSU0n2I4AuCC3E2O63vwmCb0cfE5PH99/arUX/80Q+z\neSLUallWTnCwieJi69iqdcFsYmJe4FTO10xhA3HAPuBFrEK/BJiLMmwOGAL5vPsLxPVpz6fJg3n0\n0dWkpgLkoMR+Hcrw2QU8SEzMJ3YtjWfOvJJNm/6H2Xzmjpb6pC+0q3WslFK6DxF8QXAhzmanHTpk\nk57+L5SBUkSHDrZ7waoeNtAeVd9uFfGiohzL744ty/IpKTlOy5azKC/vhqYFoaZdDbSqzOMeNjCX\nusstw1E9cF7yv4viqo9hv4Hd+zVSU5+ksjIEg8FAy5a5lJXNQ9PiwLLQKiZmPgZDJCkp92N7pzJw\noL/dqlr9oud4kevS5UKMxsI6jxXcgwi+ILgQ57PTQLDbG+o9TCazpc1viuX1AcD1wDwgBmhBdXWU\n5fjrUHl5NKqTzd1o2mbKyu5CtTK7k0BWkMjtdM2HFtRfbvk9MI+HoeoKm6MKKS9vA1wClHD69KXA\nBJt3LeP06WNkZ3fA8U7l3/++krouenXd/Yh9c35xu+APHjyY4OBg/Pz8CAgIYMWKFe4eUhC8nuzs\nSGyFMjs7khkzNmI2P255vgBVUdMH8ENJdjzqYvAekAHMwX4dbIjl8XX48xBTeJvLLM9uwb7c8kng\nAiCdIBaRALyB/YYka1G1O/r5P8T+viAEaFeniNd30bMV97i4UmbPHiT2zXnG7YJvMBhYunQpoaGh\n7h5KEHwG+4VRbTh+fCdVVRdhFdV1wAzL4+G0aJFEefkBVG/KAqAH9gIcDBThzxb+j6uJRTnt+j1E\nf+ApoCfKfd9LIPN4CUjEKub6VuOlqEla2/PnYX9fUMw111Rb2hA7l6HbintUVIistPUAbhd8TdOo\nrq529zCC4HHOpgbfcWFUVtYITKZZwDisjQysgtuyZSTl5UlYBVffHlx/vJcW/EoiH2EE2qLuC/Qz\nhANGlNjPYxgwGnXXsAw4hf1W48ss77I9fy7qziIAP78sbrklnDfeGC4Zuo9xXjL8+++/H4PBwB13\n3MHtt9/u7iEFwSMkJq4hNbUt4E96egDl5Z/z4Yf/V+ex4eFhREf3tlsYdfp0F5RfHw38jvLvwwGN\n4uJg7DPuzsALQCcCWMOdfEJHqJmYreuSsA94jf0og8d2/mA2yh5S7REgwTKOyvZbtjxAWdlUoAug\nMWKErHz1Vdwu+MuXLycqKgqTycTEiRPp1q0b/fr1q/f4qKgQd4fkEiRO1+LNcZ44Yeahh1I5eDCY\nrl2LWbAgnoiIsFrHfPVVNqA6RYLGjz++WvO5Tpww88ADKWzapAF5DBgQRocO2PnfaguRGTaP56E8\n/BIgG3v5Pgp0pxWf8hc+IRi4FPtLQk9Urm4GdhDCAn4BuqPyfNsjL0c1QHseqEB1vDcAd9Kp0zx+\n/fVxpkxJZd++X8jP38vhw0amTl1d5/dwNnjzv7ktvhKnM7hd8KOiogCIiIjg5ptvZvv27WcUfF/w\n9XzFf5Q4XUNCwqqaUsmtW4P57rt/sHHjBMLDw2r62eTktKO6ugu2QlpcHMK+fUctG4Cssus5k5Ky\nBPgNeB
WIRAl4H+yFuDd6GwNYgLVBcQkBHGIqM7kQ5ei3pnb1zS6gCEjmJ9Qq226Wcxc5HKkvv7rC\n8rt1nLCwzlRV+TN//nASElaSnj6DzEx9Edi5Z/re/m+u40txOoNbBb+0tJTq6mratGnDqVOn+O67\n75g6dao7hxSEs8IZ3z0jwx94B1XXspOsrF4MGrSEjRsn2PWzUatHrUJaWdmKQYOWEh3d27K61FbM\nNVT3Guvyp8DAX6ioGIO9ZBuAdJSFcyd+7Gc0/ehOUU0bYw1l7tyLtQ/O98AxgviI7aisvrvlqF7A\nXtQU7oVAS9RkrS781cDdNWfu3n1pzfcgPW58H7cKfn5+PlOnTsVgMFBVVcWIESPo37+/O4cUhLPC\nmRYHJtNhVCHjcvQtQbKyNJKSllJQYFuHPgyVscehWo/dR1bWf8nK8kdNehage/JwANs+OFBJUFB7\n2rV7kZycSNS0613AZtQdwAECuZ1EVnARqqmZ7eWjI+oeIBz4BQMvk4Ta+1XP6kNRtf0VwHPAEWAp\nEAXMx8+vkN69+9K5cxHwHtnZkbJdYBPErYLfuXNnUlJS3DmEIDQKZ7LWdu3iLJOr9hOnq1ZVomm7\nsWb1eulxgeXnRqADavJ1OJCEqoTRd4PNRU3QLgBKKSp63tJL/hmU4P8XmI4BE4MZSD920hvV0cYP\ne1PGhDJqnuIOVOb+PKp/zjKgHFWRU2zzGb5HZfnqDDExc9mwwdrKWL/zueOObTV3PrJIyveRlbZC\ns8aZrLVbt5Ns365qz21lVrUviEOJahCqQUEgyjL5K9a+MwuAKajFSvYNz2AkyqefZxnNgLJdDgNt\n8eN1pjKDi6gkFHUPEYqa2l2GtbNlJvAmM1C1OXrzhDCUPbPaMsYCwsJ2YzYPx/Hi1a5dnN1nru/O\nR6pzfBsRfKFZo2etGRmtMZn2kZFhJCHhUzsv33qMH/v3P83p00bUQqRhwHpU24NiVLZ+P8qqWYeq\naTegxHYJyj5xXK2q/x6MkvA2qJLMSIJYwiNs4VpqbyLeFdiBWob1Bd1YxR2oOwioPX2rHgcGmtiy\nZQJJSUvZtCnHIvzqmG7dTtl9L+LXN01E8IVmjb5w6J57PmbHji5kZYWwY0cOP/zwHuXlFwD5XHNN\nMG+8MYJZs75jx47nsbY+eB3ohxL7oSj/Xq+jz0ZZKmGWn0eAVjiuVlVoKKPmYcBAC/L4Cw/QAlUh\nb9s8Qd9E/DCQg4G5bEU1MzuFaocQgrJwnkDdfRxH1c8vY8CAkJrPW1BgJimpfntG/PqmiQi+0OQ4\n212nTCYzX32VhbWG/l8cP/4Mutilpi6jRYuNZGWFY93cIwNV6a7L8XxUhYttHf0ylKWi96XRPfVi\nlOtegrJgwlGZfSEB/I9HeYCXUFOqtvcDbVDSvhmYxzxURq8Bn6CqbabYjP0CMBbdVoqN3cE//zm+\n5jM3tEJW/PqmiQi+0OQ42z1RZ8zYSEVFP6zyGoKj9ZKSchz4BiW5T6KyedvVqi+gJktt31cIvI+q\njLH11FehBD8Y/QLhzxbGE04HrF1yjlG7q2UJLfgHPwArLec5iTJ42tuN3bZtB6677hNLtY2Z5OTx\ntS56Z7owSsuEpokIvtDkcPSfMzL87TbpePLJK5k792ebTUayUTX2+i5MjguTTKiKmj4o774QVSrp\n2Opgj8P7TqKmWG1ragpQ9fVRwDEMHORmRtGXHXRB1ebonW3uRuX/eqOFDYRyqtfrxBauo23bzhQV\n7aBduzgOHy6nqGgnaq5AjT1oUIsGBVs2C29+iOALTQ6r/1wIrGXv3kPs2KGqY9LTNdasmUVl5eya\n15V4ZwMXofZvLUJZNsrDV4+fw75LDdiLeybKdHkS1aYsFHjA8vM9VK+aXqhFVOpcLXmTKXSnDfZe\n/RKsPStPAj8Bb/MhsbGZpG+6tdbn7dv3LYqKpqAvuwoK+onk5IRaxzki
E7PNDz9PByAIrsRkMlNe\nXkFY2AcEBLwCDKWiwr7LTGVlZ8vjVNRkaxFqknMsSozboiY6W6BE+yrss/k+KL9cL4FcjppwDUDZ\nQS1R+XmY5fho1H+1/6F8/+W04s88yqNcBfzB4ewRqPqeA0ApLXibn4AJhIZ2r/Mzq5LKcJTFNJKe\nPS8/45yFjtFYiLrEgEzMNg8kwxe8nrOZhH300S9Yt05tuWetbdFw3A5Q/QxGTWr2xl5y+6Hq4/X3\nV1PbqgkDLkbZKDp6q4IdDsf/BDwGrCKA9fyFJfRE4xDKvsHh6L3Ad0AyM7Dtf1lYmFHnZ7auE1DH\nXXjh6TN9nTXIxGzzQwRf8HrOxmv+8UfbLvB6bcsAVHVMIcq6Kbc8PoaycRzr1k/avL8Y5bnvRmX9\nB1CLqqC21/+75Zi7gVmoydTjwP34cZw7uI8LqKpZLfsAyux5DGsPnP8BvxDLWraj7gqWoyZ9Aykp\naUtBgbnWxc5RuBcsGElVVcPfq0zMNj9E8AWv5+y8Zj1710V4p817v0cJcg9U1h2E6lLZBiW90ZZj\nJlse630oo4CHULZJAaqRWm/LuZdg7SMfClwLfInK9A1AJa14iGmspSWq4YFt82Mj8E/LmX/AwFv8\nRHT0ajiul4BqqN2n/CkqaklS0sZaIu0o3BERvtHhUTj/iOALXo+zi4BMJjMtWxZjbTlcSfv2p+jQ\noYqYmFOsWxeNmjgNQVXbPIG1cuYN1H+HauADVNOx6Vjl+VUgFtXorA/KytmL/cbeT6IuAN2Bv+HP\nVhKYRDBVhKHW49ree8SiMv1iYBVhFF74MqN676CkJJy0tGVAlkMMS2RiVWgUIviC1+Ho2c+ceSWO\nXrPtMR06ZAOB/PCDH2ZzT/SOM4GBc7jiigt4440refTRNahM3LZ2XpffdcCzNs/PcXjdgLqABKP6\n4kRZXg/HWk+Ti6rq6QQYaM10pvI6ccAh1KXAcQeqXagcftfVf+XzVbNqPv+QIWmoLQhXO8QQjtFo\nbvwXLDRbRPAFr8Pesy9g69YFREf3tpuwTUhYaXPMv7AX8uXAXVRUXEpqan9++WW+peVwFUpiQdkx\noGrsq7AX1hhqL3tqgbXR2WzUHMCtKBtHb5u8kAC+ZzjzuAhqvPo4y1nuxtp4YR/wHZEcaD+Jrz+c\nbPf5rXc09s3aYmN3kJw8HkE4V0TwBa/D3rNfR1bWk2RlWSds580bxKZNucBnqLr2AOBDy/GjUR0r\n56JWn75ETs5L2Fe5t8Bq59S1+2suKovXNwjMBfoC/0JZOlGoOv3lqHmAUYCBEN5iCjsJwbbxMDxt\n+WlEratNAl7hCoYOfYCvLRuB22Jt1uaPyTSXdu3i6NbtVJ2rZQXhbBDBF7wG3aZRu0Ppq17bYJt9\n792rcdllb1NW9kfURGknVL2LrdeeidqnNQJrWwMsP09TO6PvgrU12V7L43hUnX4R9nbPMlRWfxKV\nvz+PH/uJJ5o+VDAX1SvT9uzdUReAjqgan9dYQlBQFR9+OK7O70GqZwR3IYIveA22Vg5otG37EuXl\nJzl9Wm8ZUMC+fXuprn4Re4G3ldeLUNOhQ1HevGPVTjFqktb2Ob2RgV4FvxtlEd2Ftbe8fv5y1F3E\nt0B/gunNFPbQBWtdTrHD2fcAJ4C5PI1a2KURGvqi6744QXASEXzBa3AsvywpiaK6uhxYCORjMBRS\nXd0fewFuR+3e7yFY+9G/i/1WIUUoF/0pVO69H2XLrEb5+WGoCdqXUP89CrHtUaNkPRQD2VxPF66h\niDhUBX6X0UVuAAAgAElEQVRLyxHxWHtiHgR+IJhv+NYS0xLgd/r0EWtGOP9IawXB45w4YSYhYSUH\nDuzFdql/dfUhVOVLS+AhNK071kVSYF3s9CyqNn4Z1lYJuhV0G9bSS/189wDXAPcB/ijhH2455/2o\n7cCfQOXl01F2zyrURSKHAJ5kEg9y
DUX0Rjn8D6LuJf6Gala8C7WI6oer/8qivbsIC/sSVc4ZCEzn\nxIm62yQIgjtxe4b/zTffMGfOHDRNY9y4cUyaNMndQwpegG3ZZExMHgZDJdnZHepsjfDQQ6kWK8ex\nX/x0rJt+L0fVzt+OypL1TUNOowTeH+XNL0CJ/SFUZh6GyvR160evrNmCEnRQUv0qtdsi2/aoAX/2\ncieP0Qkl246LqK6wRP078D3h7G8/hc/evIvw8DAGDowmJcW6w5T0rRE8gVsFv7q6mtmzZ7N48WKi\no6O59dZbufHGG+neXbKbpo6jH6+EfDTp6Rrl5e/QokXrmjr7I0d0K0fvF78ENXG6DjWRql8AilCZ\nfABqQrYUJdIXUbsscwLwIsryse1cqS+gMqIyeX3B1KWoUk1be2h/zWMD+dxOEp1Qa2mzqb2Iapcl\nor/zHPA05GrMmbOURYuM0rdG8ArcKvi//fYbRqORjh07AjBs2DDS0tJE8JsBGRn+WCtfirGVx82b\n8ygq6g74k54eQIcOv6AmQnWhPWY51rZ0ch5KmF8GrkZZO9NRG4E4ZubBKHHvbvndtsGZXrnTEmuZ\nZXeUXD+OtavNVmASBhYxhme4kZyada/hqBoix0VUR4BlrETdbahY9JWxUnkjeANuFfzc3Fw6dOhQ\n87h9+/Zs377dnUMKHka3cvbu3U99PeSLiqqwzchPnnyRsLBXMJs7oCpkOlN7pWs0ag9Z2wqd5aiL\nQ0vs5fc3lGUzHXVn8aHl+TxUM7OHgR9Qwr4AlZf/EVv7BvJoxTKmMZN5DiPehSoYnYOq9P8dSOav\nQDLWuxn1WcW6EbwJtwq+pmkNH+RAVFSIGyJxPRJn3Uyd+rnFyvkMW8E2GNqiaUtQAh2DdW/YYIqK\nqhk6tA2pqX6orQIN1M6hg1Btix07YZajBFvvn3MUlbFnoCyhLOy3F1mG6pXzrOW5EahGadZiSj/2\n8Sce4BKUfeM4Iqj7h2LgZ+C+las5tKyYgwdX07GjCU2rICtrNV27lrBgwUgiIs7/34ov/H36Qozg\nO3E6g1sFPyYmhqysrJrHubm5REdHn/E9vtDlLyrKN7oReiLOffuCUNKod3pUQqtpLVFTnX1Q2fda\nrFn+cNLSnqBt21YUFenyOgwl4hEosR9qeY/tRWAr1lYJlUCZ5fl01DKnO1HzAbaSHYK6IDjePagW\nyi1ZwZ9ZSRRK7B0bJ+9EXbIOA/P4KzCPqsX1t2uuqjr/f9O+8PfpCzGCb8XpDG4V/EsuuYQjR45w\n7NgxoqKiWLNmDa+99po7hxRsUOWOq5zaOMRVqD4wBajVrh+ibJTTqHLIO1HSeT3wH2xFt7z8QgwG\n6ySpyqFjUcuWdGtoKMoaCkNV1kSiBL81+mbg1sVYRcBbqAzfceGVfZ8cP78f8av+HxN5kWiUQXQZ\nSuyHYu/qlwDbCWAZe1AXDqSDpeAzuFXw/f39mTVrFvfddx+apnHrrbfKhO15xFrueP42qU5OHszW\nrQvIyrLtJvMSyh/XbZxAVAXMIpS9UwSUU1b2V1Stew+sE7ftUNXtF6JEvhR1Z6B78Ccs4ziutu1v\nGddoOacR5d/HoCqBVJ+cmBgThTmnmcrrNVO8rbGK/TqsG5OUAIb7JnHqxLWQ0s0ynvj0gu/g9jr8\nAQMGMGDAAHcPI9TBwYPBOL9xyJlxdpvB8PAwoqN7k5VlK8DtgV9RkqnbOONQojsCdVF4EXVR6Im6\nIPwN6wVjBipTvxh157Ac1YuyBEgE3qb2att1KMG3nW69HVXW+RtgwJ9D9Mx5mb5Qk9lXAL9YzqqL\n/fdArrErr//8ExVVgRQUmJESS8EXkdYKTZiuXYvZurXhjUOcwXGbQb2WPiOjNSbTXiIiutC9eyXJ\nyYOJicnDXoBboeri11HbT9d/jwIWo5oRnLa8pxRVNtkVtSh8JKoLpm255nLUBWW25Ryhlvd84zBW\n
Baq0cwYQTkve5EFeJghVgW9bxT8Pa9u0X4G+H3zMjcNGEGbZSUpKLAVfRQS/CbNgQTxlZWfORPXM\nvS7hts3gHfvc/PBDMWbzg+gymZX1Pjt2BLFmzVpURfoLqJbCe1GLnlRFTm0/HcvvIVgbmC1Bib6+\n4YgJlYNrqDsAx7qZwyg//3eU7/8ZyhKy9sDx89tD9+4XcOD3KQzj31yE2tMqHzUlbHvGCNQ9QC6w\ns/f/MX2YbR2/IPguIvhNmIiIhjNRxxWxWVnL2bFjZK1NRxy3GSwpsb0AFKJE/lkqKx27wFdYfgaj\nJmuXo7L3LagFSgtQ3vpfLOcyoKpt9K0D9bJJtayp9sYkO1Fi74eaatVQ62BNGAwLMBgKCAgopLz8\nSTJ/f5XH+DctsC/UdOyGvxdY3fmv9L7iYj4Su0ZoQojgN3McM3clzKvsNh3ZsuUFIiO70bLlLMrK\nugIFVFaWohqSrUMJdBeH8/RC9YzvgzJJ/FENyu5CyeoW1MpWfd1qqOW9GsqnL0RV46iWxMHBwbRu\nncHJkyGcPDkLuBJ1FzAFa0Y/E1sZ17TOaFoorQK2MLG8KxEUcrHlXbaRhqIWUYUDu6KiefS7//FE\neIQLvl1B8C5E8Jsp1s1Gcqg94Wm/yjUnpx05OX7AH1AZ9RSsbvdc6l4odQRrqeQIm2MvRpVq9gJS\nULtP9QdmWc5/EjVluhblxa8FWtO2bQGXXRZJaupk9L48+lixsVmUlMTY1PAb0Dcab8F7TDr1Fn1Q\nS7JKLKPbRhqGarV22YrV3DZgoAu+XUHwTkTwmxG2lTbHj+8kK+shlOwtIyTkFOXlBygr80ctWtJ3\nnApFudlTsIq33mCgN9YLg75QKhrlpV+OfR6t7yk7EiXYek2+vvq1o+U1nWLgH+hZe1aWRnb2U8BS\nlGe/gKCgIMLDs4mIMFJdfYCiIpvaevZzEy25iHIuR80QBAKnLL+/iJrizQUyW7Zk8jdb6Ny1G4LQ\nlBHBb0bY+/WjgPctrxRQXNwG5YPbNv3VO0teQG3bR29yZrtQqhglpzEoJ9w2j24DVGP1823PV4i6\nSNger0u09ThNuxp1UVCWTXi4ucZ6ggJiY+cSHd2bnN8/5s8nV9ADlc3bVuC8iqrSH44ylHq/9TZT\n7ri7MV+rIPgMIvhNGMeVthkZAdgLbQFK0O+zPLbfzi8oKJLQ0AxycsB+Zep2AgK+oby8M0pC26EE\nXpU8qvLKu7GuUf0NdRFohVoEFYSSXF2GQ1H5ti7HJSg7Zz72F4GdqBYIYfj5taekRP8cAOFEtA1j\n4P4JtDhZXNPw7ANq32f8AmwA/mgptxSE5oIIfhPmgQdSSElR1S7p6RrR0S9gK6ABAcFUVtq2Frbv\nHBMenkV09CXk5NyAEu8yoAXV1X+mvPxfqIlafU3qfJTYg8r030bZNDstj5+ynONFVEbvKO6foTJ6\n2wtBEQbDU5bM/iQwGVXeeSctWhykqKhnTbxBPMWf9syhB8qmsZ3yrVXT87fnmPlIoku+Y0HwJUTw\nmzCbNtlPvppMkSi/3AAcoqoqFNiBVWSHoiZXewO7aN26DXv2bANyUPZNN5QL/jFK7HeiRPtt7Gvs\ny7BfHDUHqxVkQElxLPbibsDffzdVVXrXSwNt215ARUUwpaW23n4psbFzadu2M3v2DMOfF3iAp4lE\nTfmWoNbTrkXdY4xCFYgagX1Az7feZqRYOEIzRQTfx6ivxUFdzzvWo1RXF6AmX5cBT6BpytYJCHgG\n6EBl5WGUtbINuICMjN/RtBmo0kt9kdW/UBuRLMde1Gdh3SzcfkMSgyHC0iq72CaeocTEvMjp07EY\nDCauvroNYCQ19YGacw4atJStW49SWmr9DLGxOaSnTyMh4VP279nCgzxNR2qvvT2NMpbyLaMa3nqb\nv4rQC80cEXwfo74WB5s2VWI2twRuID09FFjK1Ve3JDX1JZS1co
yIiALy8x0nTcMJDu6C2TwRtcI1\nEHgMNUmql17G2hyvi7njxGtfVG/6LNRCKqtI33ijgV275pKV1cVyvjhiY/ewceM9hIeH1bSgLSgw\n06KF/cpgs7mQMWPmUlDQifDwTFauHMnRjAx6bnyEKyiqKcB0XHt7CDUzkN82lAlfbpIKHEFABN/n\naKjFgV4yefhwW7p00YBpNa/17fsOO3a8YKmpt9opRUU5qOy8LepPwlY+9SZluoAXYW2LYOuOH0Jd\nWKpQ3S5fom3bKAYNakFy8jAAkpI2cvhwT4uYj6/VfK2uHjXh4WGkp0+refzh669wdO7zRKNaIDiu\nHNCAHwEzEP3W20yXrF4QahDB9zEcWxyoCpnaJZNGYxHHjkXYvZafH0OfPhXk5LRCTZqGAK2orn4I\nlQ8/i6qksfXWT6ImVZfTqlUZoaEZ5OYuQS2YeslyjmKUp1+NukNQq2mvu+49OwFvTMOxXdu28Wn8\nYDprGiGWT51nGd22Z/1m4GhQax7/+nvJ6gXBARF8H0H36A8caENs7BxLk7MqysurSE21XgDCwvYw\ncGABTz55Bbfeuhpb8TYai9i06TQwFaugv4+qfAFlyVyCtY/8LtS+sGHAnUREzKWg4EJUnxtFQMBz\nVFY+jaqLWYtqoaA2B8/OjnTJZ1/98VL2JD7MGyhhn24T/VxUfVBHVGfLuLfe5nHJ6gWhTkTwfQTH\nJmeXXvoe0IKjR1sTGzuXdu3i6NbtFMnJdxIeHkZCwkoyMyej574xMb9SXh5JUVE7lH0TjxLyAlQd\n/nKs1TS6NdQH+CcBAZFERBwnK2sq6uKgoQt8dXVny/kqsDY8U6tnjcZKwPle+nXxzovPU/LmK/S0\njOLY2bIT6rK0u20od4lXLwhnRATfwzgrho7e/Y8/+mE2Wy8AV11l3c3KZDKzaVMlqi7+LgCOH99h\n6UNjK+h3Uv8kbAVwjKFDI/jww7sZMiSN48f15z9E7Vg1nerqcMv5PrR7f1jYaZKTbwZqTzQ7s/NW\nWspnbEmYQBuse8sOpfZWJzuB61es5o/SA0cQGkQE38M4K4a1vXt9az8A+92sZszYiNlchbJWQoAi\nqqvD7I4PCqrA3382JSX+qBW2O7H37gNRxY7v2Yy/FvssXu+pY8DP7xjV1db4Bg4MqLlwOV6sGtp5\n6z8LF3D0bzO4Cvu2CMtRl6VZqLqhA35+jFi3kd59Lz/j+QRBUIjgexhnxVDV1VtLFsvL29h590Zj\nUc3dwvr1oNab2u4rOwfb3HjIEPjii3KsneGvB55BbczdApVPG2p8+OTkwWza9CVms2MBJIDGLbdE\n1Cqp1HG8WNW389aWDRv45s7RdEXtcWVfza9G247aB6vlW28zQ7x6QTgr3Cb48+fP5z//+Q/t2rUD\nIDExUfa2rQNnxdCxZLGumvWkJFuf374vjiqvXEZY2GkGDgwgOXkQ69dX2RwTDlyFqrixdrLU4wkP\nD2PgQH9SUmwXQe0gOrqamJh8gHptKceLVV07b+kWzlUood9B7f2xvgcKWrfmwY1SgSMI54JbM/yJ\nEycyceJEdw7h8zgjhnWhXwD0rP6OO7ZZetvrXWRKsG5Q0gb4BX//Mq65xkhy8gjCw8MID8+yW8Wq\nO+V610nHeGrHOr5mgjgl5X7qs6XOtAfsrm3bWDlyCK0qKojCauH0x9pBPwzVmu1SaYsgCI3CrYKv\nVmoKZ+JMYujMhO6jj37BunVKbK37wd4DDMVgeBlNe9Hy2giqql4lNXUKLVooQV65cpRlFWsHNO0A\nXbv2IC5udZ2LovRY580bVBNTUtIGkpMHn7VHD8q++e6uMcRpGq1QbdG+xv5+oyewBwiZ+wp/u39S\ng+cUBOHMuFXwP/74Y1JSUrj44ot54oknCAkJcedwTQZd6Otql+B4cfjxRz9sxTYg4DQXX/yZpea+\nu4PnrpqS6YLctavRbhWrM9
Q1yWw0ak7ZUjqrP17KvsSHa/bK0uvpY7G3cHYAwdNncKeIvSC4hEYJ\n/sSJE8nPz6/1fGJiInfffTcPP/wwBoOB119/nblz5zJnzpwGzxkV5RsXBXfFeeKEmZtu+pjMTH17\nQGs1TFZWeK1xDYYT2MpkSEgxv/zyIACjRy+289z1n3FxpWcV/4kTZh56KJWDB4PZv78a2wtMVlY4\n69Zdz5Qpyzl4MJiuXUtYsGAkERG1z3/49995+7rr0PLyuBb7GYYY1KaFy1BtEQ4FBjLhhx+49Mor\nnY7zfNDc/z5diS/ECL4TpzM0SvA/+OADp467/fbbmTx5slPH5uUVNyak84Le7MsdJCSsIjPTdutA\na7uE2NiCWuNefXUbUlP1LpXFXHGFH6NHL7HYQBUMHfoemZlhnDixj4gII927L2X27EHk5RU7vQYg\nIWGVzWSw/d61sbEFVFX5M3/+8Jrjq6pq/ztu2bCBNXeOph2qyfJOVF2QXsV/AHVZy42MYsSaL7nN\nMinrTX8P7vx3dyW+EKcvxAi+FaczuM3SycvLIyoqCoAvv/ySuLg4dw3VpFB2i307ML1dgu0Eqi7W\nmZnRxMbutWm10NbOchk1ailpabcAt9Qay9k1APYe/TDCwl6hS5cLnZpkLjSZ+PD2MZz67Rcisd9A\nUe+8vxXVxrjrW2/zkEzKCoLbcJvgv/zyy+zevRs/Pz86duzI888/766hmhQxMXnArVhbIvzGpk33\n1Mq8HVst6CtthwxJw9kJVGcnW+1LR0MZOLA9ixbd2OBnKTSZePe6fgSeyKcDEIf9fUsUkI5aQnaD\nbDcoCG7HbYKfnJzsrlM3aQyGSlS/GmXRXH55O6daLehi7Wxd/9kce7alo4UmEx/eOoLAHdvpgloC\n1prabYx/B4au38TAmwf4xG2zIPg6stLWy8jO7oCavtQff1bncfWJta04x8WVMnt2/eLsrJCfqXTU\nkS0bNrDqztG0R7VAsF3nexfWNsbfA/1XrJa2CIJwHhHB9zIam3XbinNDE05nI+QNUWgy8cn/3U7Z\nT/8jGrVm19a+CQcWoBZRHbj8Sh5Y/gmh4REuGVsQBOcQwT+POFMV8+STV7J1q76l31FmzhxV57lc\nKdaNpdBk4tUr+hB56iRdgQyUjWNr32QBtGzFdau/kKxeEDyECP55xJmqmLlzfyYr60nAQGmpxpw5\nS1m0yOiJcJ1CX0TVDvsKnGewt286zn1FFlAJgodp8oJfV1ataZzzhhyNwZmqmHNpU+AJ0lI+Iz1h\nAj1QmyN2wN7CuQDV1fJXoK9U4AiCV9DkBb+urBo46w05XIEz/vzZVNl4gkKTifWJD3MkdY1da4SZ\n2Fs4+1FCP12EXhC8hiYv+PVnzOc/i3amKuZcu2eeD45mZLDwun6EVVfVqqmPRO2EG4US+/C/PSdZ\nvSB4GU1e8OvOmM+u2ZercGai1ZsmY3UKTSa+eHgSpWnrCUNtObgT1XxZb41QBJQBmZf2JfG/n0kF\njiB4IU1e8OvPmL0zi/Y2jmZk8J8Bf2RuRTnLgenozZZVa4RoVEYfcNUfeOCj/4jQC4IX0+QFv76M\n2duyaG9k17ZtpA4dRFfq3ua8N/CjwY/79hwQoRcEH6DJC75w9hzNyODzUX/C73guc4FXULZNMbW3\nHBz6xUYRe0HwEUTwBTv0rL43qtfNEdTq2GUooX8JaAvkRbfn9tVfyN6yguBDiOALgEXoRw7hgooK\nLgGGoerrXwKmAGtRk7SFLVpy7efrZbWsIPggIvjNHL0C50Taeru6erXHFrRHZfeHUZ0tbxOhFwSf\nRQTfDTi7k5SnSUv5jG8SJhAFGKlrjy3YB1RFRnHXmi/FvhEEH0cEvwHqEu+GthNzdicpT1JoMvFz\nwgQ6A0+gsnjbCdm9wGZUVi/2jSA0DUTwG6Au8f7sswlnfI8398PRWyPkfLWei4BAVKTxKBvn
JJCN\n2nLwZulXLwhNCj9PB+DtnIt4G42FqDwZvKkfztGMDN699CJOpK7huYoKgoBjqEjDgDtRi6hCbhzC\ntL2H+OOAgZ4MVxAEFyMZfgOcSzMzb+yHU2gy8emga7m2vAwT1qx+CfA00BXYHxjI7d9tFa9eEJoo\njRL8devWMX/+fDIyMlixYgV9+vSpee2dd97hk08+wd/fn6eeeor+/fs3OlhPcC7i7S39cMwnTvDf\ne+6hcPO3VBYXM1vTMAAfY83qpwFzAgOpvOkW7ntjviyiEoQmTKMEPy4ujvnz5/P000/bPZ+RkUFq\naipr164lJyeHiRMnsn79egwGQz1n8l68RbzPlqMZGSy89goiNI0eqEVU24FLUTX2r6I6XO5v2Yp7\nf9sjQi8IzYBGCX63burWX9M0u+fT0tKIj48nICCATp06YTQa+e2337jssssaM5zgJLp9E6VpdrtQ\nPY0S/FDADJyIiua2z9eL2AtCM8EtHn5ubi59+/atedy+fXtyc3PdMZTgQKHJxL8HX0e306VUYF9b\n3xVYDByL7ci9GzeL0AtCM6NBwZ84cSL5+fm1nk9MTGTw4MF1vscx4wectnMaqnH3FrwtTvOJE6Q8\n8ACZa9Yws6LCzqvXM/x9QI9Ro3j4/fcJi/Ausfe277M+JE7X4Qsxgu/E6QwNCv4HH3xw1ieNiYkh\nOzu75nFOTg7R0dFOvTcvr/isxzvfREWFeE2cu7Zt48sx8XQ9Xcpx7FfMDgNeADqixF7fW7aiyru+\nZ2/6Ps+ExOk6fCFG8K04ncFldfi2Wf3gwYNZu3Yt5eXlHD16lCNHjnDppZe6aijBhi/HxDP7dCn3\no1bM7sW6AiAU8IvtyIC9h5h+vEi2HBSEZk6jPPyvvvqK2bNnU1BQwOTJk+nZsyfvvvsuPXr0YOjQ\noQwbNoyAgACeeeYZn6zQ8WaOZmSQOm443U6X2vn03VFtEsqBnE6duCPtO/HqBUEAwKDVZbh7EF+5\nffJUnIUmE9/OeIyDa1fzXEUFy1BdLXWffhYQGhZG62uu488fLaGiKtAjcZ4NvnTbLHG6Bl+IEXwr\nTmeQlbY+gi702qYNtDSb6YZ9D5xS4ECrIG5eta6m/01YhG/8sQqCcH4QwfcRvp3xGPemfGqXydv2\nwJkT25G/pO/2ZIiCIHg5IvhejJ7Vhx4+hHbogJ1X3xO1G1V7Pz+yYzowdOUazwUqCIJPIILvxdhm\n9Y419dlhYcQMHMz1ya/JpKwgCE4hgu9l7Nq2jS9G/wljWRm5wALgblRN/SthYXTv0o1CYxfGiNAL\ngnCWiOB7GV+OiefFsrKaTH4ZkIry6SMHDub6RYs9GZ4gCD6MCL6X0a3stJ1XHwKYgoJYPGQo1ye/\n5sHIBEHwdUTwPYztxGyh0cjuwBZo5dYMvxioHjKU4ZLZC4LQSETwPYxduWX6z/x94CCe+vF7jGVl\nHDcYaHX9QMZIZi8IggsQwfcwoYcP2Vk4nQsLuftonidDEgShiSKbmJ9HCk0mPk+4l2+H3MDnCfdQ\nWGCi0Gi02e4cCo1dPBihIAhNGcnwzyOO9s1iDFyf/DqLMVg8/C4yMSsIgtsQwT+PONo3oYcPERoe\nIROygiCcF8TSOY+IfSMIgieRDN8NOJZaXp/8OqHhEWLfCILgUUTw3UBdXv3wRYvFvhEEwaOIpeMG\n6vLqBUEQPI0IvhsQr14QBG9ELB03IF69IAjeSKMEf926dcyfP5+MjAxWrFhBnz59ADh27Bjx8fF0\n69YNgMsuu4xnn3220cH6CuLVC4LgjTRK8OPi4pg/fz5PP/10rdcuuOACVq5c2ZjTC4IgCC6kUYKv\nZ/CapjVwpCAIguBp3DZpm5mZydixYxk/fjw//fSTu4YRBEEQnKTBDH/ixInk5+fXej4xMZHBgwfX\n+Z7o6Gi+/vprQkND2blzJw8//DBr1qyhTZs2DQYUFRXi
RNjnD/OJE6Q+9BDBBw9S3LUr8QsWAN4X\nZ31InK5F4nQdvhAj+E6cztCg4H/wwQdnfdLAwEBCQ0MB6NOnD507d+bQoUM1k7pnIi+v+KzHcyef\nJ0yyLqLaupXFZZVM/OwTr4uzLqKiQiROFyJxug5fiBF8K05ncJmlY+vjm0wmqqurATh69ChHjhyh\nc+fOrhrqvCKLqARBaCo0atL2q6++Yvbs2RQUFDB58mR69uzJu+++y08//cTf//53AgIC8PPz4/nn\nn6dt27auivm8Umg0oqX/XLPloCyiEgTBV2mU4N90003cdNNNtZ4fMmQIQ4YMacypvQZZRCUIQlNB\nVto2gCyiEgShqSC9dARBEJoJzVLw69pbVhAEoanTLC2d+vrVC4IgNGWaZYYvpZaCIDRHmqXgS796\nQRCaI03e0qlrf1kptRQEoTnS5AW/Pr9ePHtBEJobTd7SEb9eEARB0eQFX/x6QRAERZO3dMSvFwRB\nUDR5wZfWCIIgCIomb+kIgiAIChF8QRCEZoIIviAIQjNBBF8QBKGZIIIvCILQTBDBFwRBaCY0SvCT\nk5MZOnQoo0aNYtq0aZSUlNS89s477zBkyBCGDh3Kd9991+hABUEQhMbRKMHv378/a9asISUlBaPR\nyDvvvAPA/v37SU1NZe3atSxatIjnnnsOTdMaOJsgCILgThol+Ndeey1+fuoUffv2JScnB4ANGzYQ\nHx9PQEAAnTp1wmg08ttvvzU+WkEQBOGccZmHv2LFCgYOHAhAbm4uHTp0qHmtffv25ObmumooQRAE\n4RxosLXCxIkTyc/Pr/V8YmIigwcPBmDBggUEBgYyfPhwgDrtG4PBUOs5QRAE4fzRoOB/8MEHZ3x9\n5cqVbNq0iSVLltQ8FxMTQ3Z2ds3jnJwcoqOjnQooKirEqeM8jcTpWiRO1+ILcfpCjOA7cTpDoyyd\nb775hnfffZcFCxbQokWLmucHDx7M2rVrKS8v5+jRoxw5coRLL7200cEKgiAI545Ba0T5zJAhQ6io\nqIMzjrUAAATvSURBVCAsLAyAyy67jGeffRZQZZkrVqwgICCAp556iv79+7skYEEQBOHcaJTgC4Ig\nCL6DrLQVBEFoJojgC4IgNBNE8AVBEJoJXiv47733Hj179sRsNns6lDp58803GTlyJKNHj+b+++8n\nLy/P0yHVyZn6HXkT69atY/jw4fTq1YudO3d6Ohw7vvnmG/70pz9xyy23sHDhQk+HUy8zZ87k2muv\nZcSIEZ4OpV5ycnKYMGEC8fHxjBgxwq6c25soLy/ntttuY/To0YwYMYL58+d7OqR6qa6uZsyYMUye\nPLnhgzUvJDs7W7vvvvu0QYMGaQUFBZ4Op05KSkpqfl+yZIn29NNPezCa+tm8ebNWVVWlaZqmvfzy\ny9orr7zi4YjqJiMjQzt48KA2fvx4bceOHZ4Op4aqqirtpptu0jIzM7Xy8nJt5MiR2v79+z0dVp1s\n3bpV27VrlzZ8+HBPh1Ivx48f13bt2qVpmvo/NGTIEK/9Pk+dOqVpmqZVVlZqt912m/brr796OKK6\n+eCDD7Tp06drDz74YIPHemWGP2fOHJKSkjwdxhlp06ZNze+lpaU1PYW8jfr6HXkb3bp1o0uXLl7X\nZO+3337DaDTSsWNHAgMDGTZsGGlpaZ4Oq0769etH27ZtPR3GGYmKiqJXr16A+j/UvXt3jh8/7uGo\n6iYoKAhQ2X5lZaWHo6mbnJwcNm3axG233ebU8Q2utD3fbNiwgQ4dOnDRRRd5OpQGef3110lJSSEk\nJMRrb01tWbFiBcOGDfN0GD5FXX2htm/f7sGImg6ZmZns2bPHaxdlVldXM3bsWI4cOcKf//xnr4xT\nT46Li4udOt4jgl9ff55HH32Ud955h/fff7/mOU9mfA31EUpMTCQxMZGFCxfy0UcfMW3aNA9EeXb9\njjzp7zoTp7fhbXcc
TYWTJ0/yyCOPMHPmTLu7ZW/Cz8+Pzz77jJKSEh566CH2799Pjx49PB1WDV9/\n/TWRkZH06tWLLVu2OPUejwh+ff159u3bx7Fjxxg1ahSappGbm8u4ceP473//S7t27c5zlA33EdIZ\nPnw4Dz74oMcE/1z6HXkCZ79PbyImJoasrKyax7m5uU73hRLqprKykkceeYRRo0Zx0003eTqcBgkO\nDuYPf/gD3377rVcJ/s8//8yGDRvYtGkTZWVlnDx5kqSkJJKTk+t9j1cZz3FxcWzevJm0tDQ2bNhA\n+/btWblypUfEviEOHz5c83taWhrdunXzYDT1U1+/I2/Gm7LqSy65hCNHjnDs2DHKy8tZs2YNN954\no6fDqhdv+u7qY+bMmfTo0YN77rnH06HUi8lkqrFJTp8+zQ8//OB1/8cfe+wxvv76a9LS0njttdf4\n4x//eEaxBy/08G0xGAxe+wf86quvcvDgQfz8/IiNjeW5557zdEh18sILL1BRUcF9990H2Pc78ia+\n+uorZs+eTUFBAZMnT6Znz568++67ng4Lf39/Zs2axX333Yemadx66610797d02HVyfTp09myZQtm\ns5kbbriBadOmMW7cOE+HZce2bdtYvXo1cXFxjB49GoPBQGJiIgMGDPB0aHbk5eXxxBNPUF1dTXV1\nNfHx8TX7ffgy0ktHEAShmeBVlo4gCILgPkTwBUEQmgki+IIgCM0EEXxBEIRmggi+IAhCM0EEXxAE\noZkggi8IgtBMEMEXBEFoJvw//5K32R/vBHAAAAAASUVORK5CYII=\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7f5be3c99f50\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Current loss: 9.48636\n" + ] + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "plt.scatter(inputs, outputs, c='b')\n", + "plt.scatter(inputs, model(inputs), c='r')\n", + "plt.show()\n", + "\n", + "print('Current loss: '),\n", + "print(loss(model(inputs), outputs).numpy())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "sSDP-yeq_4jE" + }, + "source": [ + "### Define a training loop\n", + "\n", + "We now have our network and our training data. Let's train it, i.e., use the training data to update the model's variables (`W` and `b`) so that the loss goes down using [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent). There are many variants of the gradient descent scheme that are captured in `tf.train.Optimizer` implementations. 
We'd highly recommend using those implementations, but in the spirit of building from first principles, in this particular example we will implement the basic math ourselves." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "MBIACgdnA55X" + }, + "outputs": [], + "source": [ + "def train(model, inputs, outputs, learning_rate):\n", + "  with tf.GradientTape() as t:\n", + "    current_loss = loss(model(inputs), outputs)\n", + "  dW, db = t.gradient(current_loss, [model.W, model.b])\n", + "  model.W.assign_sub(learning_rate * dW)\n", + "  model.b.assign_sub(learning_rate * db)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "RwWPaJryD2aN" + }, + "source": [ + "Finally, let's repeatedly run through the training data and see how `W` and `b` evolve." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 446 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 569, + "status": "ok", + "timestamp": 1527005915434, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "XdfkR223D9dW", + "outputId": "c43591ae-d5ac-4f2b-a8e7-bfce607e0919" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0: W=5.00 b=0.00, loss=9.48636\n", + "Epoch 1: W=4.58 b=0.42, loss=6.28101\n", + "Epoch 2: W=4.24 b=0.76, loss=4.29357\n", + "Epoch 3: W=3.98 b=1.02, loss=3.06128\n", + "Epoch 4: W=3.78 b=1.23, loss=2.29721\n", + "Epoch 5: W=3.61 b=1.39, loss=1.82345\n", + "Epoch 6: W=3.49 b=1.52, loss=1.52970\n", + "Epoch 7: W=3.38 b=1.62, loss=1.34756\n", + "Epoch 8: W=3.30 b=1.70, loss=1.23463\n", + "Epoch 9: W=3.24 b=1.76, loss=1.16460\n" + ] + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAW0AAAEDCAYAAAD+/1UIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl4VOXdPvD7zJZ9XwmELQkQIAELsiTsi6xiEBGXAiIW\nbV8WBY2K0tLa4lbsr283qxURtIoioAi8SpFNg6whi0FJKAoJBgLZt5k5c87vj5OZLIRkgEnOGXJ/\nritXJsmZyT0sN1+enPOMIMuyDCIicgs6tQMQEZHzWNpERG6EpU1E5EZY2kREboSlTUTkRljaRERu\nxODMQePGjYOvry90Oh0MBgM2b97c1rmIiKgZTpW2IAjYuHEjAgIC2joPERG1wKnlEVmWIUlSW2ch\nIqJWCM5cETl+/HgEBARAEATMmTMH9957b3tkIyKiJpxaHvnggw8QFhaG4uJiLFiwAD179sTgwYPb\nOhsRETXh1PJIWFgYACA4OBgTJ05EVlZWi8fL3t6AIADdugFvvglYrTeflIiIWl8eqampgSRJ8PHx\nQXV1NR5++GEsXrwYI0aMuPadCgtRvfoFeL2zDkJtLWxdu6NqRSrMs+8DDE4N9y4XFuaHoqIKVb73\ntTCTc7SYCdBmLmZyjlYzOaPVSfvy5ct44IEHkJKSgjlz5mDcuHEtFzYAREai6oWXUHwkA9WPPApd\n4QX4L/sVgpMGwWPTvwFRdCocERE15tQPIm9Ew3/FdBcK4P3ntfB89x0IVivEmFhUP/kMzCmzAL2+\nLb79VbT6LysztU6LmQBt5mIm52g1kzPa5YpIKaozKl9+DcWHT6Jm7gLof/wB/r98BEGjh8Fj28cA\nTyckInJKu17GLnWJRuXaP6P40AnUPDgP+jN58F+0AEFjhsO0fRvLm4ioFarsPSJ1647KP/0VxWnH\nUTvnAehPf4+AhfMQNG4ETDs/A/hiOkREzVJ1wyipR09U/OV1lHx9FLX3zIH+uxwEPPQAAieMgunz\nXSxvIqImNLHLny0mDhV/fxMlB4+g9u57YMjORMDcOQicNAamPV+wvImI6miitO1scb1Q8fo6lOz/\nBrUzZsJ4Mh0B99+DwKnjYdy7h+VNRNftL395DR999IHj4+XLl2DVqlWOj//61/+HDz/8txrRboim\nStvO1iceFf96B8V702CeNgPG48cQOGcmAu+cBOOBfSxvInJa//6JyM7OAKBsfldWVorc3FzH17Oz\nM5GQMECteNdNk6VtZ+vXH+Vvv4uSPQdhnjwVxiPfIPCeGQhImQpj2ldqxyMiN5CQMBBZWZkAgLNn\nz6Bnzxj4+PigsrISVqsVP/74A+Liequc0nnqXFN+ncSEASjf8AEMJ0/A+9UX4bH7c5hSpsIycjSq\nnloJcdhwtSMSkRN8Vj8Pj+3bXPqY5jtTULX699f8emhoKPR6Ay5duoisrEz075+I6uoyZGdnwsfH\nBzExsTCotL3GjdD0pN2UOPBnKH/vI5Ts2gPL2PEwHdyPoBmTEDD7LhiOHlY7HhFpVGJiIrKyMpCd\nrZT2gAEDkJWVgaws91oaAdxk0m5KHHQ7yjZtheHIYfi8sgam/Xth2r8X5vETUZ26EuJtg9SOSETN\nqFr9+xan4rbSr18isrIy8d//KssjHh4y/vnPf8HX1wfTpt3V7nluhltN2k2JQ4aibPMnKP1kFyzJ\nI+GxZzeCJo2F/8/vhSHzpNrxiEgjEhIGIC3tIPz9/SEIAgICAlBZWYHs7Cz075+gdrzr4talbWcd\nnoyyrTtQuuUzWIYlweOL/0PQhFHwn3c/9HU/gCCijismJhbl5WXo3z+x0ef8/Pzg7+9er33bLrv8\ntStZhvHAPvi8/AcYjx0BAJin3wWPXz+Hom69lRdn0Ait7jTGTM7RYi5mco5WMznjlpi0GxEEWEeP\nRemO3Sj9YAusPxsEj88+AYYMQeCEUfBc/xaEinK1UxIR3ZBbr
7TtBAHWcRNQuutLlG7aCsycCUNO\nNvxSn0BIQm/4Ll8CQ/pxXqhDRG7l1i1tO0GAdex4YMsWFJ88hapnV0EKCYHXu+8gaNJYBI4fCc+3\n/wWhvEztpERErbr1S7sBKSIS1U88heKjmSj9YAvM02bAcOpb+D29HCGJveH7xGIYThzj9E1EmtWh\nSttBp4N13ASUv/2uMn2v/DWk0DB4vbcBQZPHIWjcCHiue5PTNxFpTscs7QakiEhUP/4kio9koHTT\nVpinzYD++1Pwe2aFMn0//j8wHD/K6ZuINKHDl7aDTgfr2PHK9J2eg8rnfgMpNBxe/96IoCnjETQ2\nmdM3kZsqLPwJ8+bNUTuGS7C0myFFRKJm2QoUHzmpTN/T74L+9HfK9J3QC77LfgXDsSOcvonciKCh\nazRuBku7Jfbpe91GXEk/hcrnV0MKj4DX++8iaOoEZfp+6w0IZaVqJyWiVoiiiD/8YTXmz78fy5Yt\ng9lsVjvSDbn1roi8BpddASVJMB7YB6+N62Ha9RkEUYTs5QXzXXejZu5DEAcPcfqqS61elcVMztFi\nLq1nWr3aA9u3u3afujvvFLF6dcsFXFj4E2bPnoF//GMd+vdPwJ/+9CI6dYrGfff93KVZbkbHvSKy\nrel0sI4Zh/K3NuDKye9Q+fxvIYVHwPOD9xA0bSKCxiTB861/cvom0piIiEjH5lAzZsxAZmaGyolu\njFtuzaoVcng4apY+gZrFy2A8uB+eG9fDY+d2+D37FHx/92uYZ8xEzbwF1zV9E93KVq82tzoVt5Wm\na9ru+leSk7Yr6HSwjh6Lin+9o0zfq34HKSISnpv+XTd9D4fnv16HUFqidlKiDquw8Cd8+202AGDH\njh1ITByocqIbw9J2MTk8HDVLHkfxN+ko3fwpau+6G/q8XPitTEVIYm/4LXkMhiOHeeYJUTvr3r0H\ndu36DPPn34+ysjKkpNyjdqQbwh9EtgOhqAiem/4Nz41vw3D2vwAAsU88DAsfxpVREyH16KlKruZo\n/QdZWqLFXMzkHK1mcgYn7XYgh4WhZvEylBw6gdKPt6M25W7oz+QBTz2FkKEDETR6OLxf/gMMWRmc\nwImoRfxBZHvS6WAdORrWkaNReeUKQr/eA/Omj2A6sA8+a1+Gz9qXYYvuCvOUabBMvRPWIcMAN3qV\naCJqe2wElcghIcDChSifcS+EygoYv/wPPHZ+BtPuz+H9xj/g/cY/IAUHwzxpKixTpsMyeizg5aV2\nbCJSGUtbA2RfP1hmzIRlxkzAYoHx64NKgf/fDni9/y683n8Xsrc3LGMnwDx1OiwTJ0EODFI7NhGp\ngKWtNSYTrGPHKy/c8PJaGE4cg8euHTDt3A6PHZ/CY8enkA0GWJNGKgU+ZRqkTlFqpyaidsLS1jKd\nDuLgIRAHD0HV86uhP/09PHZ9BtPO7TAd2AvTgb3AMytg/dkgmKdMh2XqnbDF9VI7NRG1IZ494i4E\nAbbefVD9+JMo/WI/rqTnoOLFV2EZOQaGjJPw/cNvEZw8GEFJg+Dz+9XKHuCSpHZqItVVVlZi69bN\nbfb406dPQGVlJQDgypXLGDnydmRlZTT4+kSUl7vuxcSdLm1JkjBz5kw89thjLvvmdOOkzl1Qu/BR\nlH38Ka7knEH5X/8J89Q7oS/Ih/f/voagKeMRPDAevqlPwLjvS8BiUTsykSoqKsqxdetHzX5NcsFg\n07dvArKzMwEA2dmZ6NWrD7KylI/PnfsRgYFB8Pf3v+nvY+d0aW/YsAExMTEu+8bkOnJQMMz33o/y\n9e/h8qmzKHvnfdTOeQCCuRZe699C4L0pCOkbA79fPgLT9m1A3VRA1BG8/vpfceFCAR5++EH8/e//\ni/T045g3bx5++9vnMX/+fVe9QML777+Lt99+EwBQUJCPFSuW4pFH5mHx4kU4d+7Hqx4/ISHRUdpZ\nWZmYM+dBfPttfYknJCS69
Pk4taZdWFiI/fv347HHHsPbb7/t0gDkYt7esEyZBsuUaYAowvhNGky7\nPoPHzs/g+fGH8Pz4Q8geHrCMGQfLlOkw3zEFcmio2qmpAwke1L/Zzxcfz3bJ8U398pdL8MMP/8W6\nde8BANLTjyMrKwsbNnyIyMhIFBb+dM0XSHjllTVITV2Jzp27ICcnG2vXvoQ///kfjY7p3z8R69e/\nBQA4depbPPLIY/joo38DUEo8IWGAUzmd5VRpr1mzBqmpqaio0NZln9QKgwHWEaNgHTEKVb9/GYas\nDOUslF074PH5Lnh8vgu+Oh2sQ4fDMnU6zFOmA2HN/wUhupUkJiYiMjKyxWNqamqQnZ2BVauehn23\nD1EUrzqub99+yM39HrW1tbDZbPD09ERUVGcUFOQjOzsD99/v2j27Wy3tffv2ITQ0FPHx8Th8+LDT\nD+zsdfTtqcNnGj9SeVv7CpCbC3zyCYStW2E6lAbToa/hu+pZICEBYWPGAGPGAKNGARqZwrX4ewdo\nM5fmMzWzxAAAYde68/Ue34TFUg69XufIEBjoDS8vL8fHklQNQajPaDQCOp0JwcHeCAgIwPbtn7by\nHfzQvXs37N//OQYMSEBYmB+GDBmMrKxjKC8vw6Br/E/hRrVa2idOnMCXX36J/fv3w2w2o6qqCqmp\nqXjllVdavJ8WN2NhpgYCI4H5jwLzH4Vw8SI8Pt8Jj53bYUr7CsjKAv7yFwCAGN8X1uHJsCSPhHVY\nMuQwZ/+quI4Wf+8AbeZipqvV1sqoqKh0ZCgtrQZQ31GSZMLly1dw5kwBPD09sXv3HgwbloSaGhkR\nEZ3w4YdbMXbsBABAXl4uYmPjrvoeffr0w7p1b2PhwkdRVFSBbt164YUXViE+vp/Tz93Zf2xbLe3l\ny5dj+fLlAIAjR45g3bp1rRY2uRc5IgK18xagdt4ChPmbUPrFPhi/Pghj2tcwHjsMw6kceK1TfjAj\n9u4D6/BkWJNHwjJ8BOTwcJXTE7XM3z8ACQkDMH/+fRg6NAnDhyc3+rrBYMCCBY9g0aL5iIrqjG7d\nuju+9utfv4A//vElvPPOOthsIsaPv6PZ0k5IGIDNmzehXz/llXF69+6DoqIizJgx0+XP57q2ZrWX\n9uuvv97qsfzXvnVukcligSH9BEyHvlKK/OhhCNXVji+Lcb1gHT4C1uQRsCaNgBTR8jqhSzJphBZz\nMZNztJrJGdxPW0VumclqheHkCRgPfQ3T1wdhOHIYuqr6UwjFmFhYk0Y43lxxib0Wf50AbeZiJudo\nNZMzeBk7XR+jEeLtQyHePhQ1S5crJZ55UllKSTsI4+Fv4LVxPbw2rgcA2Lr3UNbD65ZUpM5d1M1P\n5OZY2nRzjEaIg26HOOh21Cx5HBBFGLIylBI/9BWMh9Lg9d4GeL23AQBg69odluQR9SUe3VXlJ0Dk\nXlja5FoGA8TbBkG8bRBq/mcpYLPB8G0WjF9/VV/iddvNAoAtuiusSSNgsS+ndO3mvi+TTdQOWNrU\ntvR6iIkDISYORM0vFwM2G/Q538KUdtAxjXtu+jc8NylXkNk6d3Gsh1uSRkDq3kPlJ0CkLSxtal96\nPWwJiahJSETNo/8DSBL0p3Ial/hHH8Dzow8AALZOUcCY0fDqkwAxcQDEhETI/gEqPwki9bC0SV06\nHWz9+qOmX3/U/OKXSol//x2MaV/BlKYsqeD99+GL9x13EXv0VKb3hAF1RT5Aefk2og6ApU3aotPB\nFt8Xtvi+qF24CJBlhJUWonx/GgyZGcpb1kl4frIF+GSL4262LtH1JZ44AGLiwDY5Z5zcT2VlJXbv\n/j/MnHlPm32PNWt+i+TkkRg9elybfQ87ljZpmyAAvXrBHNQJ5pRZyudkGbr8844CN2RmwJhxEh67\nPoPHrs8cd7WFR9SXeMJAiIkDIHWJ5g86Oxj7ftpNS1uSJOh07vc6MCxtcj+CACm6KyzRXWG
Zdqfj\n07qLhTBknmwwkWfA4z9fwOM/XziOkYKCHAVuf7N17wm44V9edzVokE+znz9+vMolxzfVcD9tvV4P\nLy9vREVF4ttvc/Dqq39Gaurj2LBhEwBlL+3a2hosWPALFBTk47XXXkFZWSk8PT2Rmvocunbtds3v\nc/ToYXz44fsoKSnG4sVPIClphFP5rhdLm24ZUkQkLBMnwzJxsuNzwpUrMGTVl7gh82T962va7+fr\nBzEh0bE+LiYOhC02DjDwr8etoOF+2unpx5Ga+gTWrn0VRqPfTe+l3VBh4U/429/eRH7+eSxd+hg2\nbdoGo9Ho8ufDP5V0S5NDQmAdMw7WMfVrjUJ5GQzZWfVTeVYGjIcPwXTo6/r7eXlB7NvfsT4uJg6A\n2DseMJnUeBq3FGcn5Bs9vjV9+/ZDVFRUi5exO7uXdkPjxk0EAHTpEo2oqM748ccfmt1c6maxtKnD\nkf0DHOeCO1RVwZCT3WAiz4AhIx3G40fr72c0QozvpxR4/0Rg2CAIIZ2VnQ65Tu42PD09Hbf1ej1s\ntvrXibRYzAAAWZbg5+fveLUbZzSd2K81wd8sljYRAPj4OPZUcTCbYfgup9FZK4Zvs2HMPOk4JBSA\n5B8AW2wsbDFxsMX1glj33tajJ+Dh0f7PhRrx9vZGdd3OlE33xwsKCkZpaQnKy8vh6emJtLSvMGxY\nEry9fdCpUxT27v1Pq3tp2+3d+x9MnjwNFy4U4MKFghbXv28GS5voWjw8IA64DeKA2+o/Z7VCn3sa\nhqwM+F/4EeaMbOjP5MKQlQnjieON7i7rdJC6doMYG+codFtsHMTYXsqLSXA6bxcN99M2mTwQHBzs\n+Jor9tK2i47uhsWLF6GkpBhPPbWyTdazAW7Nqipmco4WMwFNcokidOd+hOFMLvS5udCfyVXKPS8X\nustFV93XMZ3H1he5LTbupqdzLf5aMZNzuDUrUXsyGCD1jIGlZwzQ4OwVABBKS6DPy4U+LxeGuvf6\nvNOtT+f2Iq9bcuF0TgBLm6jNyYFBEAcPgTh4CMwNvyCK0J/7oa7E86DPO11X7KeVc8sbnF8O1E3n\nccpSixjXS1lyccF0Ts7bsGEd9u79DwRBgCzLEAQBY8dOwNy5C9otA5dHVMRMztFiJqBtc101neee\nVpZczv4XgtXa6FjHdB7XCx59eqEyJBK26GhIXaJh69IVcmioqhO6Fn//tJrJGZy0iTTIqem8bu3c\nUFfoHrs/B3Z/Dt+mj+XlBVvnLkqJR3eFFN0VtrpCl6KjIUV2AvT6dnx2dDNY2kTuxGCArWcsbD1j\ngTumNPqSUFKM0MorKMv8Dvr8c9Dln4f+/Pm69z/CkJfb7EPKBgOkqM6wdbFP59GOYpeio2HrHM3l\nFw1haRPdIuSgYKBXN1iir3FaWmUl9PnnlUI/fx76/PPQ5Z9zFLvx0NcQrrFaaguPUAo8uiukLg0K\nvW5al32d+6893TyWNlFH4esLW5942PrEN/91sxm6CwV1ZX4e+vPn6m+fOwdDxkkYjx9r9q5SYKBS\n4F2i69bT64sdiX0A2YNLMC7C0iYihYcHpB49IfXo2fzXbTboLhbWTen1yy/224b/5kHIzmz2rqF6\nPaTQMEjhEZAiIhq/D49s9DG8vdvwSbo/ljYROUevhxTVGVJUZ4hDh139dVmGUFzcYPlFKXPv4iKI\n5wuUrXPP5ELIymjx20h+/pDCwyFFRCrvHcVu/1wEpIhIyMHBHXJLXZY2EbmGIEAOCYEYEgI0uPTf\nO8wPpQ1OrxMqK6C7dBG6ixfr3hdCd+lS3fv6z+v/e+aaa+xA3Q9Qw8KbTO0RjlJvWPJosEmUu2Np\nE1G7kn39YPP1U86AaYnVCt2Vy1eVeePCvwjD96cgZKS3+FBSQGD91B4RAXTtAm9PX0jBIZCCgyEH\nBUMKCoYcEgIpKFjTJc/SJiJtMhohRXZSziNviSxDqCh
vMq03md7r3gy5px13a/71cOoe0ttbKfSg\nukIPaVDswcH1X6u7LQcHQ/bxbZeLmFjaROTeBAGyfwBs/gHKKw61xGKB7nIRQsQqlJ45D11JMYSS\nYuiuXGl0Wygpga6kGIYzeRCqnXsRBtlobDSty0H1hS4FBSsTfXDj4pcDAq97XZ6lTUQdh8kEKaoz\nEOYHa9dezt3HbFYKvbgYuuIrSrHbbxcX15e9/eOfLsBwKseph5Z1OsiBgZCCQ4AG/wtoCUubiKgl\nHh7KEk1kJ9icvY8oQigtVQq9boq/ZvHX3XYWS5uIyNUMBsihobCFhgJOvkxkmJMP3fFOciQicmMs\nbSIiN8LSJiJyIyxtIiI30uoPIi0WCx588EFYrVbYbDZMmjQJixcvbo9sRETURKulbTKZsGHDBnh5\necFms+H+++/HqFGjkJiY2B75iIioAaeWR7y8vAAoU7coim0aiIiIrs2p0pYkCSkpKUhOTkZycjKn\nbCIilTh1cY1Op8O2bdtQWVmJX/3qV8jLy0NsbAs7dHXvjmDp6i0Vi49nN3t48KD+zX7epcfrhKsy\nqZoHuCqT6nmaZNJEngaZNJPH7tyPmsrD42+N41tzXVdE+vr6YsiQITh48GDLpQ1Ar7t6t6trvkR8\nM8e2xfFNM6mdp2kmLeRpmEkreeyZtJSnxfuolMd+/FX3UznPVffVQJ5GH2skj7MEWW5hl3EAxcXF\nMBqN8PPzQ21tLRYuXIhFixZh9OjRLT5wUYNNz7UgLMyPmZzATM7TYi5mco5WMzmj1Um7qKgIzzzz\nDCRJgiRJmDp1aquFTUREbaPV0u7duze2bt3aHlmIiKgVvCKSiMiNsLSJiNwIS5uIyI2wtImI3AhL\nm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uI\nyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiN\nsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0\niYjciKG1AwoLC5GamorLly9Dr9dj9uzZmDdvXntkIyKiJlotbb1ej2effRbx8fGoqqrC3XffjeTk\nZMTExLRHPiIiaqDV5ZGwsDDEx8cDAHx8fBATE4NLly61eTAiIrrada1p5+fn47vvvkNiYmJb5SEi\noha0ujxiV1VVhaVLl2LlypXw8fFp8dju3QFJuvqY48ermj1+0KDmH8+Vx+t0V2dSMw+AqzKpnadp\nJi3kaZhJK3nszp1r9tOq5eHxt8bxrXGqtEVRxNKlS3HXXXdhwoQJTj2wTnf1EB8W5neNY5t/DFcf\n3zST2nmaZtJCnoaZtJLHnklLeVq6j1p57Mc3vZ/aeZre1kKehh9rJY+zBFmW5dYOSk1NRVBQEJ59\n9lmnH7ioqOKGArWVsDA/ZnICMzlPi7mYyTlazeSMVte0jx8/ju3bt+Obb75BSkoKZs6ciQMHDtx0\nQCIiun6tLo8MGjQIp06dao8sRETUCl4RSUTkRljaRERuhKVNRORGWNpERG6EpU1E5EacviKSiIiu\nnyQBZWVASYmA4uL6t5ISwfG5khIBn37q3OOxtImInGSxoFHRNizgpkWsfAyUlgqQJMFlGVjaRNTh\nyDJQWYmrCvdaU7D981VVzpWvXi8jKEhGaKiMuDgJQUEygoOVt6Ag1L2XHe+DgmQAvk49NkubiG4Z\nNTVAUZGAixcFXLqkw6VLyu2iIuVj5fMCLl8GLBbnLhv39lZKtUcPqVHR1pdw4/INCZHh5wcIrhuu\nG2FpE5GmSZIyEV+6JDhK2F7IDd8uXtShvLzlpvTwkBERIWPgQMDfX2yxfO23vbza6Yk6iaVNRKqo\nqUGjwm1cwvVTcVG
RAFFsuYxDQiR07izhtttkhIfLiIiQEB5uvy3X3Zbg769MwMqGUTXt9Exdi6VN\nRC4likBhoYD8fB3OnxdQUQGcPevRYEpWStn5qVhCeLjUqIAblnJYmAyjsZ2enAawtInoutTUABcu\nCDh/Xof8fB3y8+23laK+cEGAzda0kE2OW9eaiusnYuVzbbku7M5Y2kTUSFkZGpVw49sCLl9u/po8\nQZARGSnjZz+T0KWL/U1GfLwnPD2rEBGhnE3RkabitsDSJupAZFlZR25Ywsq0XH+7oqL58dZolNG5\ns4z4eBFdusjo0kVCdLTkuB0VJcNkuvp+YWGeKCqS2viZdRwsbaJbiNUKnDvXtJDrlzIKCgSYzc2X\nso+P3KiEu3SxfywhOlpZtmjppdeofbC0idyMLAM//SQgL0+H3FwdzpzRIS9PeV9QAEhS8xdphIZK\niI9X1pPrC7m+mAMDuYbsDljaRBpVXQ2cOaOUccNyzsvTobr66naNiJCQlARERFgbTczR0TI6d5bg\n7a3CkyCXY2kTqcg+Nefm1heyfWrOz796LcLTU0bPnhJiYxu/xcQoZ1so5x/XqvBMqL2wtInagX1q\nbljK9um5uak5MlLCyJEiYmIal3OXLlxX7uhY2kQuIsvK+csNJ2b7W0HBtafmuDjJUc72277O7R1E\nHRBLm+g6WSzA99/rcOkScOKEqdH03NzU3KmTMjU3XMqIi5PQuTOnZrp+LG2iFtTUADk5OmRm6pGV\npbw/dUoHq9Vezh4AAC+va681c2omV2JpE9WprASys/XIzKwv6dOndY0uyfbwkJGQIKF/fxsGDzYh\nIqIasbGcmqn9sLSpQyopAbKylIJW3utx5kzj1vX2ljF4sA2JiRISEpT3cXGS4zLssDATiopsKqSn\njoylTbe8S5cEx9KGvaTPnWtc0AEBMkaOFJGQICEx0YbERBt69uT0TNrD0qZbhv3sjYblnJmpQ2Fh\n4+YNDZUwbpyIxESbo6S7dpV5NSC5BZY2uSVZBn74QXAUs30N+sqVxgUdFSVh8mRrgwlaQmQkC5rc\nF0ubNM9mA06f1jUq56ws/VWb6HfrJiEpyepYg05IkBAWJquUmqhtsLRJc8xmID1dj6+/1iMtTY/j\nx4Hqah/H1wVBeYXrCRPqp+f+/W0IDFQxNFE7YWmT6mprgRMnlIJOS9Pj2DE9amvrp+h+/YCEBKtj\nDbpfPxvPfaYOi6VN7a6mBjh+vL6kjx/XO/Z4FgQZfftKSE62YfhwG4YPF9G7NzdBIrJjaVObq64G\njh2rL+lQSYIFAAANpklEQVQTJ/SwWOpLun9/CUlJNiQl2TBsmIigIJUDE2kYS5tcrqrq6pK2X/at\n09WXdHKyiKFDuRZNdD1Y2nTTKiuBo0ftJW1AeroOolhf0omJ9klaKemAAJUDE7kxljZdt8pK4MgR\n+9kdBmRk1Je0Xi9jwAAJw4crk/SQITb4+6scmOgW0mppr1y5Evv27UNISAi2b9/eHplIYyoqgMOH\n6yfpjIz6TZT0ehkDB0pIShKRnGzDkCE8s4OoLbVa2nfffTfmzp2L1NTU9shDGlBeDnzzjVLQaWnK\nFYeSpJS0wSDjttskJCeLGD6cJU3U3lot7cGDB6OgoKA9spBKZBk4eVKHXbsMOHgQSE/3dZS00ajs\ndGc/u+P2223w8WnlAYmozXBNu4OyWoFDh/TYtcuAXbsMuHBB2bPDaARuv92G5GSlpAcPtvFVvIk0\npM1KOyzMr60e+oZ19EzV1cAXXwBbtwLbtyt7SgNAYCAwdy6QkgJMmgT4+BigtX/Ptfh7B2gzFzM5\nR4uZnNFmfzOLiira6qFvSFiYX4fMVFICfPGFATt3GrBvnwE1NcqyR2SkhAULREydKiIpyebY2N/H\np2P+Ot0ILeZiJudoNZMznCptWeZOae7kwgUBu3YpRZ2Wpnec6REba8PUqUpRDxwoc
YN/IjfUammv\nWLEChw8fRmlpKcaMGYMlS5Zg1qxZ7ZGNrkNurg47dypFnZ6ud3z+ttuUop4yRUSvXpKKCYnIFVot\n7bVr17ZHDrpOkqSc8WEv6rw8paj1euVls+xFHRXF/yUR3Uq09dMmapHVCqSl1Z/x8dNPyvqGl5eM\nKVOsmDpVxB13cMMlolsZS1vjqquBvXuVaXr3bgNKS5X16cBAGffeqxT1mDEiT8sj6iBY2hpUUgJ8\n/rlS1Pv315/xERUlYdYspaiHDas/44OIOg6WtkYUFAiOZY+GZ3z06lX/g8SBAyW+IC1RB8fSVtGp\nU8C775qwc6cBJ0/Wn/Hxs5/ZT82zIjaWP0gkonos7XZWWgp89JER775rxKlTAOABg0HGqFH1Z3x0\n6sSiJqLmsbTbgSwDR4/qsGGDCZ9+akBtrQCjUcbMmcCECTWYOFHkq7cQkVNY2m2orEyZqjduNOLU\nKWX5o0cPCXPnmjFnjoi+fX1RVCSqnJKI3AlL28VkGTh2TIeNG0345BPlzA+jUcZdd1kxd64VI0bY\nePk4Ed0wlraLlJUBmzcbsWFD/VTdrZuEuXMtuP9+K8LCuE5NRDePpX0TZBk4cUJZq962TZmqDQYZ\nM2YoU/XIkZyqici1WNo3oLy8fqrOyWk8Vd93nxXh4ZyqiahtsLSdJMtAeroOGzYYsW2bEdXVylQ9\nfboV8+ZZMWoUp2oianss7VZUVChT9caNRmRnK1N11671U3VEBKdqImo/LO1m2F/oduNGI7ZsUaZq\nvV7GtGnKVD16NKdqIlIHS7uBysr6qTorq36q/vnPlTNAOFUTkdpY2gAyMpS16o8/rp+qp05Vpuox\nYzhVE5F2dNjSrqwEtmxRzgDJzFSm6i5dJCxdasEDD1gRGcmpmoi0p8OVdmamDu+8o6xVV1UpU/Xk\nyVbMn69M1Xp9649BRKSWDlHalZXAtm3A3//u7dgCtXNnCYsXK1M1d9UjIndxS5d2eTnwxhsmvP66\nCeXlgE6nw+TJylr12LGcqonI/dySpV1ZCbz1lgl/+5sJpaUCQkIkrF4tICWliq9OTkRu7ZYq7epq\nYN06I/72NxOuXNEhMFDGc8+ZsXChBT16+KGoiIVNRO7tlijt2lpgwwYj/vxnE4qKdPD3l5Gaasai\nRRb4+6udjojIddy6tM1m4N13lbIuLNTBx0fG8uVmPPaYha8EQ0S3JLcsbasVeP99I/70JxMKCnTw\n9paxZIkZv/qVFSEhXAIholuXW5W2KAIffWTA2rUeOHdOB09PGY89ZsGSJRa+yAARdQhuUdo2G7Bl\niwF//KMHzp7VwWSS8cgjFixbZuF+IETUoWi6tCUJ+PRTA1591YTcXD2MRhkPPWTB449beOoeEXVI\nmixtSQJ27lTK+tQpPfR6GT//uVLWXbuyrImo49JUacsy8MUXerz8sgeys/XQ6WTMmWPF8uVm9OjB\nsiYi0kRpyzKwd69S1unpegiCjLvvtuLJJ82IjWVZExHZqVrasgwcPKiU9dGjykYgM2ZY8eSTFvTp\nI6kZjYhIk1Qr7UOH9HjpJRMOHVIiTJlixVNPWdC/P8uaiOha2r20jx7V4aWXPHDwoPKtJ04UkZpq\nxoABLGsiotY49UJaBw4cwOTJkzFp0iS88cYbN/SNTpzQ4b77vDBtmg8OHjRgzBgRu3ZV4b33aljY\nREROanXSliQJL7zwAtavX4/w8HDcc889GD9+PGJiYpz6BllZOrzyigc+/1z5ViNGiEhNtWDYMNvN\nJSci6oBaLe3MzEx069YNnTt3BgBMmzYNe/bsabW0c3J0ePVVE3bsMAIAhg4V8fTTFowYwbImIrpR\nrZb2xYsX0alTJ8fHERERyMrKavE+990HfPihN2RZwKBBNjz9tBmjR9sgCDcfmIioI2u1tGX5+s+T\n3rQJGDBAwtNPmzF+PMuaiMhVWi3tyMhIXLhww
fHxxYsXER4e3uJ9lJ7XA/C+yXiuFRbmp3aEqzCT\nc7SYCdBmLmZyjhYzOaPVs0cSEhJw7tw5FBQUwGKxYMeOHRg/fnx7ZCMioiZanbT1ej1WrVqFhx9+\nGLIs45577nH6zBEiInItQb6RRWsiIlKFUxfXEBGRNrC0iYjcCEubiMiNuHTDqAMHDmDNmjWQZRmz\nZs3CokWLXPnwN2TlypXYt28fQkJCsH37drXjAAAKCwuRmpqKy5cvQ6/XY/bs2Zg3b56qmSwWCx58\n8EFYrVbYbDZMmjQJixcvVjWTnSRJmDVrFiIiIvD666+rHQfjxo2Dr68vdDodDAYDNm/erHYkVFRU\n4LnnnkNubi50Oh3WrFmDAQMGqJrp7NmzeOKJJyAIAmRZxvnz57Fs2TLV/6yvX78emzdvhiAI6NWr\nF1588UWYTCZVM73zzjuOP0et9oHsIjabTZ4wYYKcn58vWywWecaMGXJeXp6rHv6GHT16VM7JyZGn\nT5+udhSHS5cuyTk5ObIsy3JlZaV8xx13aOLXqrq6WpZlWRZFUZ49e7ackZGhciLF22+/La9YsUJ+\n9NFH1Y4iy7Isjxs3Ti4tLVU7RiNPP/20vHnzZlmWZdlqtcoVFRUqJ2rMZrPJycnJ8oULF1TNUVhY\nKI8bN042m82yLMvysmXL5K1bt6qa6fTp0/L06dNls9ksi6IoP/TQQ/KPP/54zeNdtjzScI8So9Ho\n2KNEbYMHD4a/v7/aMRoJCwtDfHw8AMDHxwcxMTG4dOmSyqkALy8vAMrULYqiymkUhYWF2L9/P2bP\nnq12FAdZliFJ2tmZsrKyEseOHcOsWbMAAAaDAb6+viqnaiwtLQ1du3ZttCWGWiRJQk1NDURRRG1t\nbasXC7a1M2fOYODAgTCZTNDr9bj99tuxe/fuax7vstJubo8SLRSR1uXn5+O7775DYmKi2lEgSRJS\nUlKQnJyM5ORkTWRas2YNUlNTIWhoLwRBELBw4ULMmjULH374odpxkJ+fj6CgIDz77LOYOXMmVq1a\nhdraWrVjNbJz505MmzZN7RiIiIjAggULMGbMGIwaNQp+fn5ISkpSNVNcXByOHj2KsrIy1NTU4MCB\nA/jpp5+uebzLSlvm6d7XraqqCkuXLsXKlSvh4+OjdhzodDps27YNBw4cQEZGBvLy8lTNs2/fPoSG\nhiI+Pl5Tf74++OADbNmyBW+++Sbee+89HDt2TNU8oigiJycHDzzwALZu3QpPT88b3ve+LVitVnz5\n5ZeYMmWK2lFQXl6OPXv2YO/evTh48CCqq6tV/1lXTEwMfvGLX2DBggVYtGgR+vTpA4Ph2j9udFlp\n38geJR2ZKIpYunQp7rrrLkyYMEHtOI34+vpiyJAhOHjwoKo5Tpw4gS+//BLjx4/HihUrcPjwYaSm\npqqaCVCWtwAgODgYEydObHXXy7YWGRmJyMhIJCQkAAAmTZqEnJwcVTM1dODAAfTr1w/BwcFqR0Fa\nWhqio6MRGBgIvV6PiRMnIj09Xe1YmDVrFrZs2YKNGzciICAA3bp1u+axLittLe9RoqUpzW7lypWI\njY3F/Pnz1Y4CACguLkZFRQUAoLa2FocOHULPnj1VzbR8+XLs27cPe/bswWuvvYahQ4filVdeUTVT\nTU0NqqqqAADV1dX46quvEBcXp2qm0NBQdOrUCWfPngUAfPPNN5raamLHjh2YPn262jEAAFFRUcjI\nyIDZbIYsy5r5tSouLgYAXLhwAbt3727x18tlp/xpdY8S+4RWWlqKMWPGYMmSJY4f2Kjl+PHj2L59\nO3r16oWUlBQIgoAnnngCo0aNUi1TUVERnnnmGUiSBEmSMHXqVIwePVq1PFp1+fJlLF68GIIgwGaz\n4c4778SIESPUjoXnn38eTz75JERRRHR0NF588UW1IwFQBoC0tDT87ne/UzsKACAxMRGTJk1CSkoK\nDAYD+vbti
3vvvVftWFiyZAnKyspgMBjwm9/8Bn5+196BkHuPEBG5EV4RSUTkRljaRERuhKVNRORG\nWNpERG6EpU1E5EZY2kREboSlTUTkRljaRERu5P8D+7Wym3BFpegAAAAASUVORK5CYII=\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7f5be4b8ec50\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "model = Model()\n", + "\n", + "# Collect the history of W-values and b-values to plot later\n", + "Ws, bs = [], []\n", + "epochs = range(10)\n", + "for epoch in epochs:\n", + " Ws.append(model.W.numpy())\n", + " bs.append(model.b.numpy())\n", + " current_loss = loss(model(inputs), outputs)\n", + "\n", + " train(model, inputs, outputs, learning_rate=0.1)\n", + " print('Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f' %\n", + " (epoch, Ws[-1], bs[-1], current_loss))\n", + "\n", + "# Let's plot it all\n", + "plt.plot(epochs, Ws, 'r',\n", + " epochs, bs, 'b')\n", + "plt.plot([TRUE_W] * len(epochs), 'r--',\n", + " [TRUE_b] * len(epochs), 'b--')\n", + "plt.legend(['W', 'b', 'true W', 'true_b'])\n", + "plt.show()\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "vPnIVuaSJwWz" + }, + "source": [ + "## Next Steps\n", + "\n", + "In this tutorial we covered `Variable`s and built and trained a simple linear model using the TensorFlow primitives discussed so far.\n", + "\n", + "In theory, this is pretty much all you need to use TensorFlow for your machine learning research.\n", + "In practice, particularly for neural networks, the higher level APIs like `tf.keras` will be much more convenient since it provides higher level building blocks (called \"layers\"), utilities to save and restore state, a suite of loss functions, a suite of optimization strategies etc. \n", + "\n", + "The [next tutorial](TODO) will cover these higher level APIs." 
+ ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "Training Models", + "provenance": [], + "version": "0.3.2", + "views": {} + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..5749f22ac58e0a012ed7e3fec4dfe2913d3f8273 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb @@ -0,0 +1,551 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "pwX7Fii1rwsJ" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "tfe = tf.contrib.eager\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "UEu3q4jmpKVT" + }, + "source": [ + "# High level API\n", + "\n", + "We recommend using `tf.keras` as a high-level API for building neural networks. 
That said, most TensorFlow APIs are usable with eager execution.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "zSFfVVjkrrsI" + }, + "source": [ + "## Layers: common sets of useful operations\n", + "\n", + "Most of the time when writing code for machine learning models, you want to operate at a higher level of abstraction than individual operations and manipulation of individual variables.\n", + "\n", + "Many machine learning models are expressible as the composition and stacking of relatively simple layers, and TensorFlow provides both a set of many common layers as well as easy ways for you to write your own application-specific layers either from scratch or as the composition of existing layers.\n", + "\n", + "TensorFlow includes the full [Keras](https://keras.io) API in the tf.keras package, and the Keras layers are very useful when building your own models.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "8PyXlPl-4TzQ" + }, + "outputs": [], + "source": [ + "# In the tf.keras.layers package, layers are objects. To construct a layer,\n", + "# simply construct the object. Most layers take as a first argument the number\n", + "# of output dimensions / channels.\n", + "layer = tf.keras.layers.Dense(100)\n", + "# The number of input dimensions is often unnecessary, as it can be inferred\n", + "# the first time the layer is used, but it can be provided if you want to \n", + "# specify it manually, which is useful in some complex models.\n", + "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Fn69xxPO5Psr" + }, + "source": [ + "The full list of pre-existing layers can be seen in [the documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers). 
It includes Dense (a fully-connected layer),\n", + "Conv2D, LSTM, BatchNormalization, Dropout, and many others." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 204 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 244, + "status": "ok", + "timestamp": 1527783641557, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "E3XKNknP5Mhb", + "outputId": "c5d52434-d980-4488-efa7-5660819d0207" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\u003ctf.Tensor: id=30, shape=(10, 10), dtype=float32, numpy=\n", + "array([[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)\u003e" + ] + }, + "execution_count": 3, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# To use a layer, simply call it.\n", + "layer(tf.zeros([10, 5]))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 221 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 320, + "status": "ok", + "timestamp": 1527783642457, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "Wt_Nsv-L5t2s", + "outputId": "f0d96dce-0128-4080-bfe2-0ee6fbc0ad90" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[\u003ctf.Variable 'dense_1/kernel:0' shape=(5, 10) dtype=float32, 
numpy=\n", + " array([[ 0.43788117, -0.62099844, -0.30525017, -0.59352523, 0.1783089 ,\n", + " 0.47078604, -0.23620895, -0.30482283, 0.01366901, -0.1288507 ],\n", + " [ 0.18407935, -0.56550485, 0.54180616, -0.42254075, 0.3702994 ,\n", + " 0.36705834, -0.29678228, 0.36660975, 0.36717761, 0.46269661],\n", + " [ 0.1709305 , -0.11529458, 0.32710236, 0.46300393, -0.62802851,\n", + " 0.51641601, 0.39624029, 0.26918125, -0.25196898, 0.21353298],\n", + " [ 0.35752094, 0.44161648, 0.61500639, -0.12653333, 0.41629118,\n", + " 0.36193585, 0.066082 , -0.59253877, 0.47318751, 0.17115968],\n", + " [-0.22554061, -0.17727301, 0.5525015 , 0.3678053 , -0.00454676,\n", + " 0.24066836, -0.53640735, 0.13792562, -0.10727292, 0.59708995]], dtype=float32)\u003e,\n", + " \u003ctf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)\u003e]" + ] + }, + "execution_count": 4, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# Layers have many useful methods. For example, you can inspect all variables\n", + "# in a layer by calling layer.variables. 
In this case a fully-connected layer\n", + "# will have variables for weights and biases.\n", + "layer.variables" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 221 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 226, + "status": "ok", + "timestamp": 1527783643252, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "6ilvKjz8_4MQ", + "outputId": "f647fced-c2d7-41a3-c237-242036784665" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(\u003ctf.Variable 'dense_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n", + " array([[ 0.43788117, -0.62099844, -0.30525017, -0.59352523, 0.1783089 ,\n", + " 0.47078604, -0.23620895, -0.30482283, 0.01366901, -0.1288507 ],\n", + " [ 0.18407935, -0.56550485, 0.54180616, -0.42254075, 0.3702994 ,\n", + " 0.36705834, -0.29678228, 0.36660975, 0.36717761, 0.46269661],\n", + " [ 0.1709305 , -0.11529458, 0.32710236, 0.46300393, -0.62802851,\n", + " 0.51641601, 0.39624029, 0.26918125, -0.25196898, 0.21353298],\n", + " [ 0.35752094, 0.44161648, 0.61500639, -0.12653333, 0.41629118,\n", + " 0.36193585, 0.066082 , -0.59253877, 0.47318751, 0.17115968],\n", + " [-0.22554061, -0.17727301, 0.5525015 , 0.3678053 , -0.00454676,\n", + " 0.24066836, -0.53640735, 0.13792562, -0.10727292, 0.59708995]], dtype=float32)\u003e,\n", + " \u003ctf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)\u003e)" + ] + }, + "execution_count": 5, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# The variables are also accessible through nice accessors\n", + "layer.kernel, layer.bias" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "O0kDbE54-5VS" + }, + "source": [ + "## Implementing custom layers\n", + "The best way to implement your 
own layer is extending the tf.keras.layers.Layer class and implementing:\n", + " * `__init__`, where you can do all input-independent initialization\n", + " * `build`, where you know the shapes of the input tensors and can do the rest of the initialization\n", + " * `call`, where you do the forward computation\n", + "\n", + "Note that you don't have to wait until `build` is called to create your variables; you can also create them in `__init__`. However, the advantage of creating them in `build` is that it enables late variable creation based on the shape of the inputs the layer will operate on. On the other hand, creating variables in `__init__` would mean that shapes required to create the variables will need to be explicitly specified." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 391 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 251, + "status": "ok", + "timestamp": 1527783661512, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "5Byl3n1k5kIy", + "outputId": "6e7f9285-649a-4132-82ce-73ea92f15862" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tf.Tensor(\n", + "[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 
0.]], shape=(10, 10), dtype=float32)\n", + "[\u003ctf.Variable 'my_dense_layer_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n", + "array([[-0.4011991 , 0.22458655, -0.33237562, -0.25117266, 0.33528614,\n", + " -0.01392961, 0.58580834, -0.16346583, 0.28465688, -0.47191954],\n", + " [-0.52922136, 0.22416979, -0.58209574, -0.60914612, 0.05226624,\n", + " -0.18325993, 0.5591442 , -0.24718609, 0.37148207, 0.40475875],\n", + " [ 0.16912812, -0.47618777, -0.38989353, 0.30105609, -0.08085585,\n", + " 0.44758242, 0.545829 , 0.51421839, 0.11063248, 0.20159996],\n", + " [ 0.34073615, -0.59835428, 0.06498981, -0.44489855, -0.34302285,\n", + " 0.20969599, 0.35527444, -0.03173476, -0.22227573, 0.09303057],\n", + " [ 0.41764337, -0.06435019, -0.52509922, -0.39957345, 0.56811184,\n", + " 0.23481232, -0.61666459, 0.31144124, -0.11532354, -0.42421889]], dtype=float32)\u003e]\n" + ] + } + ], + "source": [ + "class MyDenseLayer(tf.keras.layers.Layer):\n", + " def __init__(self, num_outputs):\n", + " super(MyDenseLayer, self).__init__()\n", + " self.num_outputs = num_outputs\n", + " \n", + " def build(self, input_shape):\n", + " self.kernel = self.add_variable(\"kernel\", \n", + " shape=[input_shape[-1].value, \n", + " self.num_outputs])\n", + " \n", + " def call(self, input):\n", + " return tf.matmul(input, self.kernel)\n", + " \n", + "layer = MyDenseLayer(10)\n", + "print(layer(tf.zeros([10, 5])))\n", + "print(layer.variables)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "tk8E2vY0-z4Z" + }, + "source": [ + "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`.\n", + "\n", + "Overall code is easier to read and maintain if it uses standard layers whenever possible, as other readers will be familiar with the behavior of standard layers. 
If you want to use a layer which is not present in tf.keras.layers or tf.contrib.layers, consider filing a [github issue](http://github.com/tensorflow/tensorflow/issues/new) or, even better, sending us a pull request!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Qhg4KlbKrs3G" + }, + "source": [ + "## Models: composing layers\n", + "\n", + "Many interesting layer-like things in machine learning models are implemented by composing existing layers. For example, each residual block in a resnet is a composition of convolutions, batch normalizations, and a shortcut.\n", + "\n", + "The main class used when creating a layer-like thing which contains other layers is tf.keras.Model. Implementing one is done by inheriting from tf.keras.Model." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 190 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 420, + "status": "ok", + "timestamp": 1527783698512, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "N30DTXiRASlb", + "outputId": "a8b23a8e-5cf9-4bbf-f93b-6c763d74e2b3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tf.Tensor(\n", + "[[[[ 0. 0. 0.]\n", + " [ 0. 0. 0.]\n", + " [ 0. 0. 0.]]\n", + "\n", + " [[ 0. 0. 0.]\n", + " [ 0. 0. 0.]\n", + " [ 0. 0. 
0.]]]], shape=(1, 2, 3, 3), dtype=float32)\n", + "['resnet_identity_block_1/conv2d_3/kernel:0', 'resnet_identity_block_1/conv2d_3/bias:0', 'resnet_identity_block_1/batch_normalization_3/gamma:0', 'resnet_identity_block_1/batch_normalization_3/beta:0', 'resnet_identity_block_1/conv2d_4/kernel:0', 'resnet_identity_block_1/conv2d_4/bias:0', 'resnet_identity_block_1/batch_normalization_4/gamma:0', 'resnet_identity_block_1/batch_normalization_4/beta:0', 'resnet_identity_block_1/conv2d_5/kernel:0', 'resnet_identity_block_1/conv2d_5/bias:0', 'resnet_identity_block_1/batch_normalization_5/gamma:0', 'resnet_identity_block_1/batch_normalization_5/beta:0', 'resnet_identity_block_1/batch_normalization_3/moving_mean:0', 'resnet_identity_block_1/batch_normalization_3/moving_variance:0', 'resnet_identity_block_1/batch_normalization_4/moving_mean:0', 'resnet_identity_block_1/batch_normalization_4/moving_variance:0', 'resnet_identity_block_1/batch_normalization_5/moving_mean:0', 'resnet_identity_block_1/batch_normalization_5/moving_variance:0']\n" + ] + } + ], + "source": [ + "class ResnetIdentityBlock(tf.keras.Model):\n", + " def __init__(self, kernel_size, filters):\n", + " super(ResnetIdentityBlock, self).__init__(name='')\n", + " filters1, filters2, filters3 = filters\n", + "\n", + " self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))\n", + " self.bn2a = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')\n", + " self.bn2b = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))\n", + " self.bn2c = tf.keras.layers.BatchNormalization()\n", + "\n", + " def call(self, input_tensor, training=False):\n", + " x = self.conv2a(input_tensor)\n", + " x = self.bn2a(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2b(x)\n", + " x = self.bn2b(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2c(x)\n", + " 
x = self.bn2c(x, training=training)\n", + "\n", + " x += input_tensor\n", + " return tf.nn.relu(x)\n", + "\n", + " \n", + "block = ResnetIdentityBlock(1, [1, 2, 3])\n", + "print(block(tf.zeros([1, 2, 3, 3])))\n", + "print([x.name for x in block.variables])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "wYfucVw65PMj" + }, + "source": [ + "Much of the time, however, models which compose many layers simply call one layer after the other. This can be done in very little code using tf.keras.Sequential" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "base_uri": "https://localhost:8080/", + "height": 153 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 361, + "status": "ok", + "timestamp": 1526674830777, + "user": { + "displayName": "Alexandre Passos", + "photoUrl": "//lh4.googleusercontent.com/-kmTTWXEgAPw/AAAAAAAAAAI/AAAAAAAAAC0/q_DoOzKGwds/s50-c-k-no/photo.jpg", + "userId": "108023195365833072773" + }, + "user_tz": 420 + }, + "id": "L9frk7Ur4uvJ", + "outputId": "882e9076-b6d9-4380-bb1e-7c6b57d54c39" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\u003ctf.Tensor: id=1423, shape=(1, 2, 3, 3), dtype=float32, numpy=\n", + "array([[[[0., 0., 0.],\n", + " [0., 0., 0.],\n", + " [0., 0., 0.]],\n", + "\n", + " [[0., 0., 0.],\n", + " [0., 0., 0.],\n", + " [0., 0., 0.]]]], dtype=float32)\u003e" + ] + }, + "execution_count": 26, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + " my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1)),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(2, 1, \n", + " padding='same'),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(3, (1, 1)),\n", + " tf.keras.layers.BatchNormalization()])\n", + "my_seq(tf.zeros([1, 2, 3, 3]))" + ] + }, + { + "cell_type": "markdown", + 
"metadata": { + "colab_type": "text", + "id": "c5YwYcnuK-wc" + }, + "source": [ + "# Next steps\n", + "\n", + "Now you can go back to the previous notebook and adapt the linear regression example to use layers and models to be better structured." + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "4 - High level API - TensorFlow Eager.ipynb", + "provenance": [], + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index b8f352d5f5b72ffb8ae81a2bb72974c7fd65bd5a..b14ef1df8ff4c660b9b6f2abfd5df6572d10b1e8 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -49,15 +49,17 @@ def random_batch(batch_size, data_format): return images, one_hot -def train_one_step(model, images, labels, optimizer): - - with tfe.GradientTape() as tape: +def compute_gradients(model, images, labels): + with tf.GradientTape() as tape: logits = model(images, training=True) loss = tf.losses.softmax_cross_entropy( logits=logits, onehot_labels=labels) tf.contrib.summary.scalar(name='loss', tensor=loss) - grads = tape.gradient(loss, model.variables) - optimizer.apply_gradients(zip(grads, model.variables)) + return tape.gradient(loss, model.variables) + + +def apply_gradients(model, optimizer, gradients): + optimizer.apply_gradients(zip(gradients, model.variables)) class ResNet50Test(tf.test.TestCase): @@ -114,7 +116,8 @@ class ResNet50Test(tf.test.TestCase): with tf.device(device), tfe.execution_mode(execution_mode): optimizer = tf.train.GradientDescentOptimizer(0.1) images, labels = random_batch(2, data_format) - train_one_step(model, images, labels, optimizer) + apply_gradients(model, optimizer, 
+ compute_gradients(model, images, labels)) self.assertEqual(320, len(model.variables)) tfe.async_wait() events = summary_test_util.events_from_logdir(logdir) @@ -138,14 +141,16 @@ class ResNet50Test(tf.test.TestCase): # garbage to be collected. The hope is that this is a build-only effect, # and a subsequent training loop will create nothing which needs to be # collected. - train_one_step(model, images, labels, optimizer) + apply_gradients(model, optimizer, + compute_gradients(model, images, labels)) gc.collect() previous_gc_debug_flags = gc.get_debug() gc.set_debug(gc.DEBUG_SAVEALL) for _ in range(2): # Run twice to ensure that garbage that is created on the first # iteration is no longer accessible. - train_one_step(model, images, labels, optimizer) + apply_gradients(model, optimizer, + compute_gradients(model, images, labels)) gc.collect() # There should be no garbage requiring collection. self.assertEqual(0, len(gc.garbage)) @@ -180,9 +185,7 @@ class ResNet50Benchmarks(tf.test.Benchmark): return (16, 32, 64) if tf.DeviceSpec.from_string(device.name).device_type == 'TPU': - # TODO(iga): Training fails with batch size of 16, probably because of - # no layout optimizations with op-by-op mode. Investigate more. 
- return (8,) + return (32,) return (16, 32) def _report(self, label, start, num_iters, device, batch_size, data_format): @@ -248,18 +251,21 @@ class ResNet50Benchmarks(tf.test.Benchmark): device, data_format = device_and_format for batch_size in self._train_batch_sizes(): (images, labels) = random_batch(batch_size, data_format) - num_burn = 3 - num_iters = 10 model = resnet50.ResNet50(data_format) + optimizer = tf.train.GradientDescentOptimizer(0.1) + apply_grads = apply_gradients if defun: model.call = tfe.defun(model.call, compiled=compiled) - optimizer = tf.train.GradientDescentOptimizer(0.1) + apply_grads = tfe.defun(apply_gradients, compiled=compiled) + num_burn = 3 + num_iters = 10 with tf.device(device): iterator = make_iterator((images, labels)) for _ in xrange(num_burn): (images, labels) = iterator.next() - train_one_step(model, images, labels, optimizer) + apply_grads(model, optimizer, + compute_gradients(model, images, labels)) if execution_mode: tfe.async_wait() self._force_device_sync() @@ -268,7 +274,8 @@ class ResNet50Benchmarks(tf.test.Benchmark): start = time.time() for _ in xrange(num_iters): (images, labels) = iterator.next() - train_one_step(model, images, labels, optimizer) + apply_grads(model, optimizer, + compute_gradients(model, images, labels)) if execution_mode: tfe.async_wait() self._force_device_sync() diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 492adbe1d80941f9df96d6636e4933d11239408e..5ee2176154ec7011dcb3d7b384a86213e778014f 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -152,7 +152,7 @@ class RNNColorbot(tf.keras.Model): self.label_dimension = label_dimension self.keep_prob = keep_prob - self.cells = self._add_cells( + self.cells = tf.contrib.checkpoint.List( [tf.nn.rnn_cell.BasicLSTMCell(size) for size in 
rnn_cell_sizes]) self.relu = layers.Dense( label_dimension, activation=tf.nn.relu, name="relu") @@ -204,14 +204,6 @@ class RNNColorbot(tf.keras.Model): hidden_states = tf.gather_nd(chars, indices) return self.relu(hidden_states) - def _add_cells(self, cells): - # "Magic" required for keras.Model classes to track all the variables in - # a list of layers.Layer objects. - # TODO(ashankar): Figure out API so user code doesn't have to do this. - for i, c in enumerate(cells): - setattr(self, "cell-%d" % i, c) - return cells - def loss(labels, predictions): """Computes mean squared loss.""" diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py index 75b342ba78bd5de5c2827296f6fba01ffa86d560..b7d8395e277b526ba40ccafa323ba453a8667b62 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py @@ -67,5 +67,5 @@ class RNNColorbotTest(tf.test.TestCase): if __name__ == "__main__": - tfe.enable_eager_execution() + tf.enable_eager_execution() tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index be5d60449d7e08c99cc28e76befce56f468c77fd..c2340a293a80924f2dfa90e2fb23134b0f1feb6b 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -50,7 +50,7 @@ class RNN(tf.keras.Model): def __init__(self, hidden_dim, num_layers, keep_ratio): super(RNN, self).__init__() self.keep_ratio = keep_ratio - self.cells = self._add_cells([ + self.cells = tf.contrib.checkpoint.List([ tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim) for _ in range(num_layers) ]) @@ -74,14 +74,6 @@ class RNN(tf.keras.Model): # tuple (output, output_states). 
return [input_seq] - def _add_cells(self, cells): - # "Magic" required for keras.Model classes to track all the variables in - # a list of Layer objects. - # TODO(ashankar): Figure out API so user code doesn't have to do this. - for i, c in enumerate(cells): - setattr(self, "cell-%d" % i, c) - return cells - class Embedding(layers.Layer): """An Embedding layer.""" @@ -304,7 +296,7 @@ def test_model(use_cudnn_rnn): def main(_): - tfe.enable_eager_execution() + tf.enable_eager_execution() if not FLAGS.data_path: raise ValueError("Must specify --data-path") diff --git a/tensorflow/contrib/eager/python/examples/scan/BUILD b/tensorflow/contrib/eager/python/examples/scan/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..638c57d1c92c1dce0ef9e73e9a6ac2369358080b --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/scan/BUILD @@ -0,0 +1,25 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +cuda_py_test( + name = "scan_test", + size = "small", + srcs = ["scan_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "scan_graph_test", + size = "small", + srcs = ["scan_graph_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d4b8c8941ec411912f3089315d038fc4bcd049ae --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py @@ -0,0 +1,54 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit test for tf.scan under graph mode execution.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np +import tensorflow as tf + + +class ScanBenchmark(tf.test.Benchmark): + + def runScan(self, n): + elems = np.arange(n) + start_time = time.time() + sum_op = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1) + with tf.Session() as sess: + sess.run(sum_op) + wall_time = time.time() - start_time + + self.report_benchmark( + name='scan', + iters=n, + wall_time=wall_time) + + def benchmarkScan16000(self): + self.runScan(16000) + + def benchmarkScan32000(self): + self.runScan(32000) + + def benchmarkScan64000(self): + self.runScan(64000) + + def benchmarkScan128000(self): + self.runScan(128000) + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a02fc24c79dae6c2565db8b138b1d7391d169ed8 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/scan/scan_test.py @@ -0,0 +1,54 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit test for tf.scan under eager execution.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np +import tensorflow as tf + + +class ScanBenchmark(tf.test.Benchmark): + + def runScan(self, n): + elems = np.arange(n) + start_time = time.time() + _ = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1) + wall_time = time.time() - start_time + + self.report_benchmark( + name='scan', + iters=n, + wall_time=wall_time) + + def benchmarkScan16000(self): + self.runScan(16000) + + def benchmarkScan32000(self): + self.runScan(32000) + + def benchmarkScan64000(self): + self.runScan(64000) + + def benchmarkScan128000(self): + self.runScan(128000) + + +if __name__ == '__main__': + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index 1e4746d01ca1a8d13162844bc064c479c7184237..8ac553e0ae71382966d03d9ef4429adf5137b369 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -36,8 +36,8 @@ from third_party.examples.eager.spinn import spinn from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import test from tensorflow.python.framework import test_util -from tensorflow.python.training import checkpointable_utils from 
tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import util as checkpointable_utils # pylint: enable=g-bad-import-order diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index 907f9204c2d31a652ca2a0539a23db4722b4e154..c947ed9dcc415670a820f8a5cd9eaaf07334cfc3 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -25,12 +25,13 @@ from tensorflow.python.eager import function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.training import checkpointable +from tensorflow.python.training.checkpointable import base as checkpointable _to_replace = re.compile("[^A-Za-z0-9.]") @@ -367,6 +368,9 @@ class Accuracy(Mean): Returns: The arguments, for easy chaining. 
""" + check_ops.assert_equal( + array_ops.shape(labels), array_ops.shape(predictions), + message="Shapes of labels and predictions are unequal") matches = math_ops.equal(labels, predictions) matches = math_ops.cast(matches, dtypes.float64) super(Accuracy, self).call(matches, weights=weights) diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index f0fe4ce8c53bb80c03a3f0de37078bcdb975a0b4..02ee05487515b81bfae70d02c1dfdb6d816b77c7 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -26,12 +26,13 @@ from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import summary_ops_v2 as summary_ops -from tensorflow.python.training import checkpointable_utils from tensorflow.python.training import training_util +from tensorflow.python.training.checkpointable import util as checkpointable_utils class MetricsTest(test.TestCase): @@ -117,6 +118,11 @@ class MetricsTest(test.TestCase): self.assertEqual(dtypes.float64, m.dtype) self.assertEqual(dtypes.float64, m.result().dtype) + def testAccuracyDifferentShapes(self): + m = metrics.Accuracy() + with self.assertRaises(errors.InvalidArgumentError): + m([[0], [0]], [0, 1]) + def testWeightedAccuracy(self): m = metrics.Accuracy() # 1 correct, total weight of 2 @@ -146,8 +152,6 @@ class MetricsTest(test.TestCase): self.assertAllEqual(2.0, m2.result()) def testNamesWithSpaces(self): - # Verify two metrics with the same class and name don't - # accidentally share state. 
m1 = metrics.Mean("has space") m1(0) self.assertEqual(m1.name, "has space") @@ -186,8 +190,8 @@ class MetricsTest(test.TestCase): self.assertEqual(self.evaluate(value), 2.5) def testTwoMeansGraph(self): - # Verify two metrics with the same class and name don't - # accidentally share state. + # Verify two metrics with the same name in the same graph raises a + # ValueError. with context.graph_mode(): m1 = metrics.Mean() m1(0) diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index 44828bea50c660815e457f21a1990cd706c40876..f801d9a47b2f831a48d9b6335c69612c1356d800 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -23,9 +23,8 @@ import os import weakref from tensorflow.python.eager import context -from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import ops -from tensorflow.python.keras._impl.keras.engine import base_layer as keras_base_layer +from tensorflow.python.keras.engine import base_layer as keras_base_layer from tensorflow.python.layers import base from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging @@ -33,6 +32,7 @@ from tensorflow.python.training import checkpoint_utils from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util from tensorflow.python.util import deprecation +from tensorflow.python.util import function_utils # pylint: disable=protected-access # Explanation for protected-access disable: Network has lots of same-class and @@ -545,10 +545,10 @@ class Sequential(Network): def add(self, layer_func): if isinstance(layer_func, base.Layer): - args = estimator_util.fn_args(layer_func.call) + args = function_utils.fn_args(layer_func.call) self.track_layer(layer_func) elif callable(layer_func): - args = estimator_util.fn_args(layer_func) + args = function_utils.fn_args(layer_func) else: raise 
TypeError( "Sequential.add() takes only tf.layers.Layer objects or callables; " diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index 6a51d03de52914d2ad0ac3ad05d1ba01d856ad9a..c92bd15b253b67a3301cd562046a4467e1bf877d 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -30,8 +30,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.training import checkpointable_utils from tensorflow.python.training import training_util +from tensorflow.python.training.checkpointable import util as checkpointable_utils # pylint: disable=not-callable diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py index 4032e755f6e7dea9dcb42587f14e8386e5db2338..90a3711475719a7f991473c6c9067da1e76ab9f2 100644 --- a/tensorflow/contrib/eager/python/saver_test.py +++ b/tensorflow/contrib/eager/python/saver_test.py @@ -60,15 +60,9 @@ class SaverTest(test.TestCase): def testSameNameNoClobbering(self): with ops.device(self._dev()): - # Note that this test purposefully uses Graphs rather than - # IsolateTest. Users are more likely to accidentally create the same - # variable name this way. 
- first_graph = ops.Graph() - with first_graph.as_default(): - v1_first_graph = resource_variable_ops.ResourceVariable(1.0, name='v1') - with ops.Graph().as_default(): - v1_second_graph = resource_variable_ops.ResourceVariable(2.0, name='v1') - saver = _saver.Saver([v1_first_graph, v1_second_graph]) + v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') + v2 = resource_variable_ops.ResourceVariable(2.0, name='v1') + saver = _saver.Saver([v1, v2]) ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') with self.assertRaisesRegexp(ValueError, 'v1'): saver.save(ckpt_prefix) @@ -126,12 +120,11 @@ class SaverTest(test.TestCase): saver = _saver.Saver([v1]) saver.save(ckpt_prefix) - with ops.Graph().as_default(): - saver = _saver.Saver([v1]) - with _saver.restore_variables_on_create(ckpt_prefix): - # Value is from checkpoint, but not from argument. - ret, _ = model(2.0) - self.assertEqual(ret.numpy(), 1.0) + saver = _saver.Saver([v1]) + with _saver.restore_variables_on_create(ckpt_prefix): + # Value is from checkpoint, but not from argument. 
+ ret, _ = model(2.0) + self.assertEqual(ret.numpy(), 1.0) def testRestoreNotFound(self): with ops.device(self._dev()): @@ -184,17 +177,17 @@ class SaverTest(test.TestCase): 4, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) # reset the graph and reload on create, so that 1 + 2 = 3 - with ops.Graph().as_default(): - with _saver.restore_variables_on_create(ckpt_prefix): - @graph_callable.graph_callable( - [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) - def model2(x): - v = variable_scope.get_variable( - 'v', initializer=init_ops.zeros_initializer(), shape=()) - return v + x - - self.assertEqual( - 3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy()) + ops.reset_default_graph() + with _saver.restore_variables_on_create(ckpt_prefix): + @graph_callable.graph_callable( + [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) + def model2(x): + v = variable_scope.get_variable( + 'v', initializer=init_ops.zeros_initializer(), shape=()) + return v + x + + self.assertEqual( + 3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy()) class GetOptimizerTests(test.TestCase): diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 79dd117854e5fe9f066f671d8ce62e08579e0ed9..fee9db46fa4f79d7dd613436726e8ddad51faf1c 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -115,14 +115,15 @@ from tensorflow.python.eager.execution_callbacks import seterr from tensorflow.python.framework.ops import enable_eager_execution from tensorflow.python.framework.ops import eager_run as run from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes +from tensorflow.python.framework.test_util import run_all_in_graph_and_eager_modes as run_all_tests_in_graph_and_eager_modes from tensorflow.python.ops.custom_gradient import custom_gradient from tensorflow.python.ops.resource_variable_ops import 
ResourceVariable as Variable from tensorflow.python.ops.variable_scope import EagerVariableStore from tensorflow.python.ops import script_ops from tensorflow.python.ops import template -from tensorflow.python.training.checkpointable import Checkpointable -from tensorflow.python.training.checkpointable_utils import CheckpointableSaver -from tensorflow.python.training.checkpointable_utils import Checkpoint +from tensorflow.python.training.checkpointable.base import Checkpointable +from tensorflow.python.training.checkpointable.util import CheckpointableSaver +from tensorflow.python.training.checkpointable.util import Checkpoint from tensorflow.python.util.all_util import remove_undocumented py_func = script_ops.eager_py_func diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index e9a68801efccc1a74450359eb672caaa51ad73e8..1937ffb583bc727df76470d072b35fb3c9acaa88 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -14,12 +14,14 @@ py_library( srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ + ":baseline", ":boosted_trees", ":dnn", ":dnn_linear_combined", ":export", ":extenders", ":head", + ":hooks", ":linear", ":logit_fns", ":multi_head", @@ -29,6 +31,49 @@ py_library( ], ) +py_library( + name = "baseline", + srcs = ["python/estimator/baseline.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:baseline", + ], +) + +py_test( + name = "baseline_test", + size = "small", + srcs = ["python/estimator/baseline_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", + ], + deps = [ + ":baseline", + ":head", + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:session", + 
"//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:metric_keys", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_library( name = "boosted_trees", srcs = ["python/estimator/boosted_trees.py"], @@ -267,6 +312,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:training", + "//tensorflow/python:variables", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:prediction_keys", @@ -277,6 +323,37 @@ py_test( ], ) +py_library( + name = "hooks", + srcs = [ + "python/estimator/hooks.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/estimator:estimator_py", + ], +) + +py_test( + name = "hooks_test", + size = "medium", + srcs = ["python/estimator/hooks_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], + deps = [ + ":hooks", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator:estimator_py", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_library( name = "linear", srcs = ["python/estimator/linear.py"], @@ -322,9 +399,9 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:framework_ops", + "//tensorflow/python:util", "//tensorflow/python/estimator:dnn", "//tensorflow/python/estimator:linear", - "//tensorflow/python/estimator:util", ], ) diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index ec502f86ddb724c403e00bd21da4f7e970849d4e..788ac5ca7046d6dd30a3d5520b243944532622fa 
100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -19,12 +19,14 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.estimator.python.estimator.baseline import * from tensorflow.contrib.estimator.python.estimator.boosted_trees import * from tensorflow.contrib.estimator.python.estimator.dnn import * from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * from tensorflow.contrib.estimator.python.estimator.export import * from tensorflow.contrib.estimator.python.estimator.extenders import * from tensorflow.contrib.estimator.python.estimator.head import * +from tensorflow.contrib.estimator.python.estimator.hooks import * from tensorflow.contrib.estimator.python.estimator.linear import * from tensorflow.contrib.estimator.python.estimator.logit_fns import * from tensorflow.contrib.estimator.python.estimator.multi_head import * @@ -39,12 +41,14 @@ _allowed_symbols = [ 'binary_classification_head', 'clip_gradients_by_norm', 'forward_features', + 'InMemoryEvaluatorHook', 'logistic_regression_head', 'multi_class_head', 'multi_head', 'multi_label_head', 'poisson_regression_head', 'regression_head', + 'BaselineEstimator', 'DNNEstimator', 'DNNLinearCombinedEstimator', 'LinearEstimator', diff --git a/tensorflow/contrib/estimator/python/estimator/baseline.py b/tensorflow/contrib/estimator/python/estimator/baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..beffbee73064b9ef425b115317c43e29477b19af --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/baseline.py @@ -0,0 +1,98 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Baseline estimators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator.canned import baseline + + +class BaselineEstimator(estimator.Estimator): + """An estimator that can establish a simple baseline. + + The estimator uses a user-specified head. + + This estimator ignores feature values and will learn to predict the average + value of each label. E.g. for single-label classification problems, this will + predict the probability distribution of the classes as seen in the labels. + For multi-label classification problems, it will predict the ratio of examples + that contain each class. + + Example: + + ```python + + # Build baseline multi-label classifier. + estimator = BaselineEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3)) + + # Input builders + def input_fn_train(): # returns x, y (where y represents label's class index). + pass + + def input_fn_eval(): # returns x, y (where y represents label's class index). + pass + + # Fit model. + estimator.train(input_fn=input_fn_train) + + # Evaluates cross entropy between the test and train labels. + loss = estimator.evaluate(input_fn=input_fn_eval)["loss"] + + # For each class, predicts the ratio of training examples that contain the + # class. 
+ predictions = estimator.predict(new_samples) + + ``` + + Input of `train` and `evaluate` should have the following features, + otherwise there will be a `KeyError`: + + * if `weight_column` passed to the `head` constructor is not `None`, a feature + with `key=weight_column` whose value is a `Tensor`. + """ + + def __init__(self, + head, + model_dir=None, + optimizer='Ftrl', + config=None): + """Initializes a BaselineEstimator instance. + + Args: + head: A `_Head` instance constructed with a method such as + `tf.contrib.estimator.multi_label_head`. + model_dir: Directory to save model parameters, graph, etc. This can + also be used to load checkpoints from the directory into an estimator to + continue training a previously saved model. + optimizer: String, `tf.Optimizer` object, or callable that creates the + optimizer to use for training. If not specified, will use + `FtrlOptimizer` with a default learning rate of 0.3. + config: `RunConfig` object to configure the runtime settings. + """ + def _model_fn(features, labels, mode, config): + return baseline._baseline_model_fn( # pylint: disable=protected-access + features=features, + labels=labels, + mode=mode, + head=head, + optimizer=optimizer, + config=config) + super(BaselineEstimator, self).__init__( + model_fn=_model_fn, + model_dir=model_dir, + config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d0e3e670f7332811c1bfdaea65b0308ce59ade59 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/baseline_test.py @@ -0,0 +1,430 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for baseline.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import tempfile + +import numpy as np +import six + +from tensorflow.contrib.estimator.python.estimator import baseline +from tensorflow.contrib.estimator.python.estimator import head as head_lib +from tensorflow.python.client import session as tf_session +from tensorflow.python.estimator.canned import metric_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import optimizer +from tensorflow.python.training import saver + +# Names of variables created by model. 
+BIAS_NAME = 'baseline/bias' + + +def assert_close(expected, actual, rtol=1e-04, name='assert_close'): + with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope: + expected = ops.convert_to_tensor(expected, name='expected') + actual = ops.convert_to_tensor(actual, name='actual') + rdiff = math_ops.abs(expected - actual, 'diff') / math_ops.abs(expected) + rtol = ops.convert_to_tensor(rtol, name='rtol') + return check_ops.assert_less( + rdiff, + rtol, + data=('Condition expected =~ actual did not hold element-wise:' + 'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff, + 'rtol = ', rtol,), + name=scope) + + +def save_variables_to_ckpt(model_dir): + init_all_op = [variables.global_variables_initializer()] + with tf_session.Session() as sess: + sess.run(init_all_op) + saver.Saver().save(sess, os.path.join(model_dir, 'model.ckpt')) + + +def _baseline_estimator_fn( + weight_column=None, label_dimension=1, *args, **kwargs): + """Returns a BaselineEstimator that uses regression_head.""" + return baseline.BaselineEstimator( + head=head_lib.regression_head( + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. 
+ loss_reduction=losses.Reduction.SUM), + *args, **kwargs) + + +class BaselineEstimatorEvaluationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def test_evaluation_batch(self): + """Tests evaluation for batch_size==2.""" + with ops.Graph().as_default(): + variables.Variable([13.0], name=BIAS_NAME) + variables.Variable( + 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_estimator = _baseline_estimator_fn(model_dir=self._model_dir) + eval_metrics = baseline_estimator.evaluate( + input_fn=lambda: ({'age': ((1,), (1,))}, ((10.,), (10.,))), steps=1) + + # Logit is bias = 13, while label is 10. + # Loss per example is 3**2 = 9. + # Training loss is the sum over batch = 9 + 9 = 18 + # Average loss is the average over batch = 9 + self.assertDictEqual({ + metric_keys.MetricKeys.LOSS: 18., + metric_keys.MetricKeys.LOSS_MEAN: 9., + ops.GraphKeys.GLOBAL_STEP: 100 + }, eval_metrics) + + def test_evaluation_weights(self): + """Tests evaluation with weights.""" + with ops.Graph().as_default(): + variables.Variable([13.0], name=BIAS_NAME) + variables.Variable( + 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + def _input_fn(): + features = {'age': ((1,), (1,)), 'weights': ((1.,), (2.,))} + labels = ((10.,), (10.,)) + return features, labels + + baseline_estimator = _baseline_estimator_fn( + weight_column='weights', + model_dir=self._model_dir) + eval_metrics = baseline_estimator.evaluate(input_fn=_input_fn, steps=1) + + # Logit is bias = 13, while label is 10. + # Loss per example is 3**2 = 9. 
+ # Training loss is the weighted sum over batch = 9 + 2*9 = 27 + # average loss is the weighted average = (9 + 2*9) / (1 + 2) = 9 + self.assertDictEqual({ + metric_keys.MetricKeys.LOSS: 27., + metric_keys.MetricKeys.LOSS_MEAN: 9., + ops.GraphKeys.GLOBAL_STEP: 100 + }, eval_metrics) + + def test_evaluation_for_multi_dimensions(self): + label_dim = 2 + with ops.Graph().as_default(): + variables.Variable([46.0, 58.0], name=BIAS_NAME) + variables.Variable(100, name='global_step', dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_estimator = _baseline_estimator_fn( + label_dimension=label_dim, + model_dir=self._model_dir) + input_fn = numpy_io.numpy_input_fn( + x={ + 'age': np.array([[2., 4., 5.]]), + }, + y=np.array([[46., 58.]]), + batch_size=1, + num_epochs=None, + shuffle=False) + eval_metrics = baseline_estimator.evaluate(input_fn=input_fn, steps=1) + + self.assertItemsEqual( + (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN, + ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys()) + + # Logit is bias which is [46, 58] + self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS]) + + +class BaselineEstimatorPredictTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def test_1d(self): + """Tests predict when all variables are one-dimensional.""" + with ops.Graph().as_default(): + variables.Variable([.2], name=BIAS_NAME) + variables.Variable(100, name='global_step', dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_estimator = _baseline_estimator_fn(model_dir=self._model_dir) + + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': np.array([[2.]])}, + y=None, + batch_size=1, + num_epochs=1, + shuffle=False) + predictions = baseline_estimator.predict(input_fn=predict_input_fn) + predicted_scores = list([x['predictions'] for x in predictions]) + # The 
baseline ignores features, so score = bias = .2
+ self.assertAllClose([[.2]], predicted_scores) + + def testMultiDim(self): + """Tests predict when all variables are multi-dimensional.""" + batch_size = 2 + label_dimension = 3 + with ops.Graph().as_default(): + variables.Variable( # shape=[label_dimension] + [.2, .4, .6], name=BIAS_NAME) + variables.Variable(100, name='global_step', dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_estimator = _baseline_estimator_fn( + label_dimension=label_dimension, + model_dir=self._model_dir) + + predict_input_fn = numpy_io.numpy_input_fn( + # x shape=[batch_size, x_dim] + x={'x': np.array([[1., 2., 3., 4.], [5., 6., 7., 8.]])}, + y=None, + batch_size=batch_size, + num_epochs=1, + shuffle=False) + predictions = baseline_estimator.predict(input_fn=predict_input_fn) + predicted_scores = list([x['predictions'] for x in predictions]) + # score = bias, shape=[batch_size, label_dimension] + self.assertAllClose([[0.2, 0.4, 0.6], [0.2, 0.4, 0.6]], + predicted_scores) + + +class BaselineEstimatorIntegrationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn, + input_dimension, label_dimension, prediction_length): + feature_columns = [ + feature_column_lib.numeric_column('x', shape=(input_dimension,)) + ] + est = _baseline_estimator_fn( + label_dimension=label_dimension, + model_dir=self._model_dir) + + # TRAIN + # learn y = x + est.train(train_input_fn, steps=200) + + # EVALUATE + scores = est.evaluate(eval_input_fn) + self.assertEqual(200, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn(metric_keys.MetricKeys.LOSS, six.iterkeys(scores)) + + # PREDICT + predictions = np.array( + [x['predictions'] for x in est.predict(predict_input_fn)]) + self.assertAllEqual((prediction_length, label_dimension), 
predictions.shape) + + # EXPORT + feature_spec = feature_column_lib.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = est.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def test_numpy_input_fn(self): + """Tests complete flow with numpy_input_fn.""" + label_dimension = 2 + input_dimension = label_dimension + batch_size = 10 + prediction_length = batch_size + data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) + data = data.reshape(batch_size, label_dimension) + + train_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=1, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=None, + batch_size=batch_size, + num_epochs=1, + shuffle=False) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + input_dimension=input_dimension, + label_dimension=label_dimension, + prediction_length=prediction_length) + + +class BaselineEstimatorTrainingTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _mock_optimizer(self, expected_loss=None): + expected_var_names = [ + '%s:0' % BIAS_NAME + ] + + def _minimize(loss, global_step=None, var_list=None): + trainable_vars = var_list or ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertItemsEqual(expected_var_names, + [var.name for var in trainable_vars]) + + # Verify loss. We can't check the value directly, so we add an assert op. 
+ self.assertEquals(0, loss.shape.ndims) + if expected_loss is None: + if global_step is not None: + return distribute_lib.increment_var(global_step) + return control_flow_ops.no_op() + assert_loss = assert_close( + math_ops.to_float(expected_loss, name='expected'), + loss, + name='assert_loss') + with ops.control_dependencies((assert_loss,)): + if global_step is not None: + return distribute_lib.increment_var(global_step) + return control_flow_ops.no_op() + + mock_optimizer = test.mock.NonCallableMock( + spec=optimizer.Optimizer, + wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer')) + mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize) + + # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks. + # So, return mock_optimizer itself for deepcopy. + mock_optimizer.__deepcopy__ = lambda _: mock_optimizer + return mock_optimizer + + def _assert_checkpoint(self, + label_dimension, + expected_global_step, + expected_bias=None): + shapes = { + name: shape + for (name, shape) in checkpoint_utils.list_variables(self._model_dir) + } + + self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP]) + self.assertEqual(expected_global_step, + checkpoint_utils.load_variable(self._model_dir, + ops.GraphKeys.GLOBAL_STEP)) + + self.assertEqual([label_dimension], shapes[BIAS_NAME]) + if expected_bias is not None: + self.assertEqual(expected_bias, + checkpoint_utils.load_variable(self._model_dir, + BIAS_NAME)) + + def testFromScratch(self): + # Create BaselineRegressor. + label = 5. + age = 17 + # loss = (logits - label)^2 = (0 - 5.)^2 = 25. + mock_optimizer = self._mock_optimizer(expected_loss=25.) + baseline_estimator = _baseline_estimator_fn( + model_dir=self._model_dir, + optimizer=mock_optimizer) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. 
+ num_steps = 10 + baseline_estimator.train( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + label_dimension=1, + expected_global_step=num_steps, + expected_bias=[0.]) + + def testFromCheckpoint(self): + # Create initial checkpoint. + bias = 7.0 + initial_global_step = 100 + with ops.Graph().as_default(): + variables.Variable([bias], name=BIAS_NAME) + variables.Variable( + initial_global_step, + name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + # logits = bias = 7. + # loss = (logits - label)^2 = (7 - 5)^2 = 4 + mock_optimizer = self._mock_optimizer(expected_loss=4.) + baseline_estimator = _baseline_estimator_fn( + model_dir=self._model_dir, + optimizer=mock_optimizer) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + baseline_estimator.train( + input_fn=lambda: ({'age': ((17,),)}, ((5.,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + label_dimension=1, + expected_global_step=initial_global_step + num_steps, + expected_bias=[bias]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/contrib/estimator/python/estimator/dnn.py index cf6e3329d2e27735d8759cc2ab3726e8c624c6ae..7ff25b95c079c7e06d29e874bcaa0d2c13e7167e 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn.py @@ -93,7 +93,7 @@ class DNNEstimator(estimator.Estimator): dropout=None, input_layer_partitioner=None, config=None): - """Initializes a `DNNClassifier` instance. + """Initializes a `DNNEstimator` instance. 
Args: head: A `_Head` instance constructed with a method such as diff --git a/tensorflow/contrib/estimator/python/estimator/export.py b/tensorflow/contrib/estimator/python/estimator/export.py index e7e366a3f26fa60ea7867c128799fe358b027bdf..03cf6f107c1c5589522d7be4946562a466740b0e 100644 --- a/tensorflow/contrib/estimator/python/estimator/export.py +++ b/tensorflow/contrib/estimator/python/estimator/export.py @@ -60,38 +60,16 @@ def export_saved_model_for_mode( with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: loader.load(sess, [tag_constants.TRAINING], export_dir) + weights = graph.get_tensor_by_name('linear/linear_model/age/weights') ... ``` - This method takes an input_receiver_fn and mode. For the mode passed in, - this method builds a new graph by calling the input_receiver_fn to obtain - feature and label `Tensor`s. Next, this method calls the `Estimator`'s - model_fn in the passed mode to generate the model graph based on - those features and labels, and restores the given checkpoint - (or, lacking that, the most recent checkpoint) into the graph. - Finally, it creates a timestamped export directory below the - export_dir_base, and writes a `SavedModel` into it containing - the `MetaGraphDef` for the given mode and its associated signatures. - - For prediction, the exported `MetaGraphDef` will provide one `SignatureDef` - for each element of the export_outputs dict returned from the model_fn, - named using the same keys. One of these keys is always - signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which - signature will be served when a serving request does not specify one. - For each signature, the outputs are provided by the corresponding - `ExportOutput`s, and the inputs are always the input receivers provided by - the serving_input_receiver_fn. + This method is a wrapper for _export_all_saved_models, and wraps a raw + input_receiver_fn in a dictionary to pass in to that function. 
+ See _export_all_saved_models for full docs. - For training and evaluation, the train_op is stored in an extra collection, - and loss, metrics, and predictions are included in a SignatureDef for the - mode in question. - - Extra assets may be written into the SavedModel via the assets_extra - argument. This should be a dict, where each key gives a destination path - (including the filename) relative to the assets.extra directory. The - corresponding value gives the full path of the source file to be copied. - For example, the simple case of copying a single file without renaming it - is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. + See tf.contrib.estimator.export_saved_model_for_mode for the currently + exposed version of this function. Args: estimator: an instance of tf.estimator.Estimator @@ -138,10 +116,39 @@ def export_all_saved_models( # pylint: disable=line-too-long """Exports requested train/eval/predict graphs as separate SavedModels. - This is a wrapper around export_saved_model_for_mode that accepts - multiple modes simultaneously and creates directories for each under - export_dir_base. See `Estimator.export_saved_model_for_mode` for - further details as to how the export works for each mode. + See tf.contrib.estimator.export_all_saved_models for the currently + exposed version of this function. + + For each mode passed in via the input_receiver_fn_map, + this method builds a new graph by calling the input_receiver_fn to obtain + feature and label `Tensor`s. Next, this method calls the `Estimator`'s + model_fn in the passed mode to generate the model graph based on + those features and labels, and restores the given checkpoint + (or, lacking that, the most recent checkpoint) into the graph. + Only one of the modes is used for saving variables to the SavedModel + (order of preference: TRAIN, EVAL, then PREDICT), such that up to three + MetaGraphDefs are saved with a single set of variables in a single + SavedModel directory. 
+ + For prediction, the exported `MetaGraphDef` will provide one `SignatureDef` + for each element of the export_outputs dict returned from the model_fn, + named using the same keys. One of these keys is always + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which + signature will be served when a serving request does not specify one. + For each signature, the outputs are provided by the corresponding + `ExportOutput`s, and the inputs are always the input receivers provided by + the serving_input_receiver_fn. + + For training and evaluation, the train_op is stored in an extra collection, + and loss, metrics, and predictions are included in a SignatureDef for the + mode in question. + + Extra assets may be written into the SavedModel via the assets_extra + argument. This should be a dict, where each key gives a destination path + (including the filename) relative to the assets.extra directory. The + corresponding value gives the full path of the source file to be copied. + For example, the simple case of copying a single file without renaming it + is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. Sample usage: ```python @@ -166,7 +173,7 @@ def export_all_saved_models( model_fn_lib.ModeKeys.PREDICT: serve_rcvr_fn, } - export_dirs = tf.contrib.estimator.export_all_saved_models( + export_dir = tf.contrib.estimator.export_all_saved_models( classifier, export_dir_base='my_model/', input_receiver_fn_map=rcvr_fn_map) @@ -175,8 +182,8 @@ def export_all_saved_models( # can be used for serving, analysis with TFMA, or directly loaded in. with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: - loader.load(sess, [tag_constants.TRAINING], - export_dirs[tf.estimator.ModeKeys.TRAIN]) + loader.load(sess, [tag_constants.TRAINING], export_dir) + weights = graph.get_tensor_by_name('linear/linear_model/age/weights') ... 
``` diff --git a/tensorflow/contrib/estimator/python/estimator/export_test.py b/tensorflow/contrib/estimator/python/estimator/export_test.py index 89d02582e18e39ee35730e7674691ed9638a3e50..050821ee672f30a6926c4a0a0e48915515d9afd7 100644 --- a/tensorflow/contrib/estimator/python/estimator/export_test.py +++ b/tensorflow/contrib/estimator/python/estimator/export_test.py @@ -166,12 +166,9 @@ class EstimatorExportTest(test.TestCase): input_receiver_fn_map = { model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() } - export_dirs, tmpdir = self._test_export_all_saved_models( + export_dir, tmpdir = self._test_export_all_saved_models( input_receiver_fn_map) - self.assertEqual(len(export_dirs), 1) - # Restore, to validate that the export was well-formed. - export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: loader.load(sess, [tag_constants.SERVING], export_dir) @@ -188,12 +185,9 @@ class EstimatorExportTest(test.TestCase): input_receiver_fn_map = { model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), } - export_dirs, tmpdir = self._test_export_all_saved_models( + export_dir, tmpdir = self._test_export_all_saved_models( input_receiver_fn_map) - self.assertEqual(len(export_dirs), 1) - # Restore, to validate that the export was well-formed. - export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: loader.load(sess, [tag_constants.TRAINING], export_dir) @@ -211,12 +205,9 @@ class EstimatorExportTest(test.TestCase): input_receiver_fn_map = { model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() } - export_dirs, tmpdir = self._test_export_all_saved_models( + export_dir, tmpdir = self._test_export_all_saved_models( input_receiver_fn_map) - self.assertEqual(len(export_dirs), 1) - # Restore, to validate that the export was well-formed. 
- export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL] with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: loader.load(sess, [tag_constants.EVAL], export_dir) @@ -235,12 +226,9 @@ class EstimatorExportTest(test.TestCase): model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() } - export_dirs, tmpdir = self._test_export_all_saved_models( + export_dir, tmpdir = self._test_export_all_saved_models( input_receiver_fn_map) - self.assertEqual(len(export_dirs), 2) - # Restore, to validate that the export was well-formed. - export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: loader.load(sess, [tag_constants.TRAINING], export_dir) @@ -249,7 +237,7 @@ class EstimatorExportTest(test.TestCase): self.assertFalse('eval_multiplied' in graph_ops) self.assertTrue('feature_x' in graph_ops) self.assertTrue('weight' in graph_ops) - export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL] + with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: loader.load(sess, [tag_constants.EVAL], export_dir) @@ -270,12 +258,11 @@ class EstimatorExportTest(test.TestCase): model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() } - export_dirs, tmpdir = self._test_export_all_saved_models( + export_dir, tmpdir = self._test_export_all_saved_models( input_receiver_fn_map) # Restore, to validate that the export was well-formed. 
- for mode, tag_set in model_fn_lib.EXPORT_TAG_MAP.items(): - export_dir = export_dirs[mode] + for tag_set in model_fn_lib.EXPORT_TAG_MAP.values(): with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: loader.load(sess, tag_set, export_dir) @@ -292,10 +279,9 @@ class EstimatorExportTest(test.TestCase): model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() } - export_dirs, tmpdir = self._test_export_all_saved_models( + export_dir, tmpdir = self._test_export_all_saved_models( input_receiver_fn_map) - export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: loader.load(sess, [tag_constants.TRAINING], export_dir) @@ -303,7 +289,6 @@ class EstimatorExportTest(test.TestCase): self.assertTrue('later_var' in graph_ops) self.assertTrue('weight' in graph_ops) - export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: loader.load(sess, [tag_constants.SERVING], export_dir) @@ -319,10 +304,9 @@ class EstimatorExportTest(test.TestCase): model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() } - export_dirs, tmpdir = self._test_export_all_saved_models( + export_dir, tmpdir = self._test_export_all_saved_models( input_receiver_fn_map) - export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: loader.load(sess, [tag_constants.TRAINING], export_dir) @@ -332,7 +316,6 @@ class EstimatorExportTest(test.TestCase): collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) self.assertEqual(3, collection_vars[-1].eval()) - export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: 
loader.load(sess, [tag_constants.SERVING], export_dir) @@ -360,16 +343,15 @@ class EstimatorExportTest(test.TestCase): # Perform the export. export_dir_base = os.path.join( compat.as_bytes(tmpdir), compat.as_bytes('export')) - export_dirs = contrib_export.export_all_saved_models( + export_dir = contrib_export.export_all_saved_models( est, export_dir_base, input_receiver_fn_map) # Check that all the files are in the right places. self.assertTrue(gfile.Exists(export_dir_base)) - for _, export_dir in export_dirs.items(): - self._validate_exported_files(export_dir) + self._validate_exported_files(export_dir) - return export_dirs, tmpdir + return export_dir, tmpdir def _validate_exported_files(self, export_dir): self.assertTrue(gfile.Exists(export_dir)) diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py index 201699ed775f701bc9f215fff11a688175d51645..bf08be09e7baf63e507a6a4db6a91e7b6bb20b74 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders.py @@ -22,12 +22,12 @@ import six from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator import util as estimator_util from tensorflow.python.estimator.export.export_output import PredictOutput from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.ops import clip_ops from tensorflow.python.training import optimizer as optimizer_lib +from tensorflow.python.util import function_utils _VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config']) @@ -330,7 +330,7 @@ class _TransformGradients(optimizer_lib.Optimizer): def _verify_metric_fn_args(metric_fn): - args = set(estimator_util.fn_args(metric_fn)) + args = set(function_utils.fn_args(metric_fn)) invalid_args = 
list(args - _VALID_METRIC_FN_ARGS) if invalid_args: raise ValueError('metric_fn (%s) has following not expected args: %s' % @@ -339,7 +339,7 @@ def _verify_metric_fn_args(metric_fn): def _call_metric_fn(metric_fn, features, labels, predictions, config): """Calls metric fn with proper arguments.""" - metric_fn_args = estimator_util.fn_args(metric_fn) + metric_fn_args = function_utils.fn_args(metric_fn) kwargs = {} if 'features' in metric_fn_args: kwargs['features'] = features diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 109fdd3883427ab93fd289b9621141f5281bd7d0..9594e5132fd20dadea118fd1dd6768feb7fd7fff 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import six + from tensorflow.python.estimator import model_fn from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import metric_keys @@ -72,6 +74,33 @@ def multi_class_head(n_classes, shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to the input labels before passing them to `loss_fn`. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.multi_class_head(n_classes=3) + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. 
Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.multi_class_head(n_classes=3) + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.train.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: n_classes: Number of classes, must be greater than 2 (for 2 classes, use `binary_classification_head`). @@ -139,6 +168,33 @@ def binary_classification_head( shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to the input labels before passing them to `loss_fn`. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.binary_classification_head() + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.binary_classification_head() + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.train.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -211,6 +267,33 @@ def regression_head(weight_column=None, https://en.wikipedia.org/wiki/Generalized_linear_model#Link_function Namely, for poisson regression, set `inverse_link_fn=tf.exp`. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.regression_head() + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...)
+ ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.regression_head() + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.train.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -270,6 +353,33 @@ def poisson_regression_head( This is implemented as a generalized linear model, see https://en.wikipedia.org/wiki/Generalized_linear_model. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.poisson_regression_head() + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.poisson_regression_head() + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.train.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -337,6 +447,33 @@ def logistic_regression_head( This is implemented as a generalized linear model, see https://en.wikipedia.org/wiki/Generalized_linear_model. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.logistic_regression_head() + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...)
+ ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.logistic_regression_head() + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.train.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -375,6 +512,7 @@ def multi_label_head(n_classes, label_vocabulary=None, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, + classes_for_class_based_metrics=None, name=None): """Creates a `_Head` for multi-label classification. @@ -391,6 +529,7 @@ def multi_label_head(n_classes, applications, the shape is `[batch_size, n_classes]`. Labels can be: + * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]` * An integer `SparseTensor` of class indices. The `dense_shape` must be `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`. @@ -406,6 +545,33 @@ def multi_label_head(n_classes, shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies `label_vocabulary` to the input labels before passing them to `loss_fn`. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.multi_label_head(n_classes=3) + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`.
Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.multi_label_head(n_classes=3) + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.train.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: n_classes: Number of classes, must be greater than 1 (for 1 class, use `binary_classification_head`). @@ -427,6 +593,10 @@ def multi_label_head(n_classes, reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by batch size. See `tf.losses.Reduction`. loss_fn: Optional loss function. + classes_for_class_based_metrics: List of integer class IDs or string class + names for which per-class metrics are evaluated. If integers, all must be + in the range `[0, n_classes - 1]`. If strings, all must be in + `label_vocabulary`. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -434,8 +604,8 @@ def multi_label_head(n_classes, An instance of `_Head` for multi-label classification. Raises: - ValueError: if `n_classes`, `thresholds`, `loss_reduction` or `loss_fn` is - invalid. + ValueError: if `n_classes`, `thresholds`, `loss_reduction`, `loss_fn` or + `classes_for_class_based_metrics` is invalid.
""" thresholds = tuple(thresholds) if thresholds else tuple() if n_classes is None or n_classes < 2: @@ -460,10 +630,31 @@ def multi_label_head(n_classes, if (loss_reduction not in losses.Reduction.all() or loss_reduction == losses.Reduction.NONE): raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction)) + classes_for_class_based_metrics = tuple( + [] if classes_for_class_based_metrics is None + else classes_for_class_based_metrics) + if classes_for_class_based_metrics: + if isinstance(classes_for_class_based_metrics[0], six.string_types): + if not label_vocabulary: + raise ValueError( + 'label_vocabulary must be provided when ' + 'classes_for_class_based_metrics are string.') + class_ids = [] + for class_string in classes_for_class_based_metrics: + class_ids.append(label_vocabulary.index(class_string)) + classes_for_class_based_metrics = tuple(class_ids) + else: + for class_id in classes_for_class_based_metrics: + if (class_id < 0) or (class_id >= n_classes): + raise ValueError( + 'All classes_for_class_based_metrics must be in range [0, {}].
' + 'Given: {}'.format(n_classes - 1, class_id)) return _MultiLabelHead( n_classes=n_classes, weight_column=weight_column, thresholds=thresholds, label_vocabulary=label_vocabulary, loss_reduction=loss_reduction, - loss_fn=loss_fn, name=name) + loss_fn=loss_fn, + classes_for_class_based_metrics=classes_for_class_based_metrics, + name=name) class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access @@ -476,6 +667,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access label_vocabulary=None, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, + classes_for_class_based_metrics=None, name=None): self._n_classes = n_classes self._weight_column = weight_column @@ -483,6 +675,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access self._label_vocabulary = label_vocabulary self._loss_reduction = loss_reduction self._loss_fn = loss_fn + self._classes_for_class_based_metrics = classes_for_class_based_metrics self._name = name @property @@ -653,6 +846,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access train_op = train_op_fn(regularized_training_loss) else: raise ValueError('train_op_fn and optimizer cannot both be None.') + train_op = head_lib._append_update_ops(train_op) # pylint:disable=protected-access # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. 
if self._loss_reduction == losses.Reduction.SUM: @@ -737,4 +931,36 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access weights=weights, threshold=threshold, name=recall_key)) + for class_id in self._classes_for_class_based_metrics: + batch_rank = array_ops.rank(probabilities) - 1 + begin = array_ops.concat( + [array_ops.zeros([batch_rank], dtype=dtypes.int32), [class_id]], + axis=0) + size = array_ops.concat( + [-1 * array_ops.ones([batch_rank], dtype=dtypes.int32), [1]], + axis=0) + class_probabilities = array_ops.slice( + probabilities, begin=begin, size=size) + class_labels = array_ops.slice(labels, begin=begin, size=size) + prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id + metric_ops[head_lib._summary_key(self._name, prob_key)] = ( # pylint:disable=protected-access + head_lib._predictions_mean( # pylint:disable=protected-access + predictions=class_probabilities, + weights=weights, + name=prob_key)) + auc_key = keys.AUC_AT_CLASS % class_id + metric_ops[head_lib._summary_key(self._name, auc_key)] = ( # pylint:disable=protected-access + head_lib._auc( # pylint:disable=protected-access + labels=class_labels, + predictions=class_probabilities, + weights=weights, + name=auc_key)) + auc_pr_key = keys.AUC_PR_AT_CLASS % class_id + metric_ops[head_lib._summary_key(self._name, auc_pr_key)] = ( # pylint:disable=protected-access + head_lib._auc( # pylint:disable=protected-access + labels=class_labels, + predictions=class_probabilities, + weights=weights, + curve='PR', + name=auc_pr_key)) return metric_ops diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index 19b86df5565a85168bdbc37076a0af69248a8010..b2b57fa06ba818d4455871fe57dde5ce287b39a2 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops 
import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops +from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -175,6 +176,21 @@ class MultiLabelHead(test.TestCase): r'loss_fn has unexpected args: \[\'name\'\]'): head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn) + def test_classes_for_class_based_metrics_invalid(self): + with self.assertRaisesRegexp( + ValueError, + r'All classes_for_class_based_metrics must be in range \[0, 2\]\. ' + r'Given: -1'): + head_lib.multi_label_head( + n_classes=3, classes_for_class_based_metrics=[2, -1]) + + def test_classes_for_class_based_metrics_string_invalid(self): + with self.assertRaisesRegexp( + ValueError, r'\'z\' is not in list'): + head_lib.multi_label_head( + n_classes=3, label_vocabulary=['a', 'b', 'c'], + classes_for_class_based_metrics=['c', 'z']) + def test_name(self): head = head_lib.multi_label_head(n_classes=4, name='foo') self.assertEqual('foo', head.name) @@ -591,6 +607,81 @@ class MultiLabelHead(test.TestCase): expected_loss=expected_loss, expected_metrics=expected_metrics) + def test_eval_with_classes_for_class_based_metrics(self): + head = head_lib.multi_label_head( + n_classes=2, classes_for_class_based_metrics=[0, 1]) + + logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) + labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + # loss = labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits)) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels, logits=logits)) + + keys = metric_keys.MetricKeys + expected_metrics = { + # Average loss over examples. + keys.LOSS_MEAN: expected_loss, + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. 
+ keys.AUC: 0.3333, + keys.AUC_PR: 0.7639, + keys.PROBABILITY_MEAN_AT_CLASS % 0: np.sum(_sigmoid(logits[:, 0])) / 2., + keys.AUC_AT_CLASS % 0: 0., + keys.AUC_PR_AT_CLASS % 0: 1., + keys.PROBABILITY_MEAN_AT_CLASS % 1: np.sum(_sigmoid(logits[:, 1])) / 2., + keys.AUC_AT_CLASS % 1: 1., + keys.AUC_PR_AT_CLASS % 1: 1., + } + + self._test_eval( + head=head, + logits=logits, + labels=labels, + expected_loss=expected_loss, + expected_metrics=expected_metrics) + + def test_eval_with_classes_for_class_based_metrics_string(self): + head = head_lib.multi_label_head( + n_classes=2, label_vocabulary=['a', 'b'], + classes_for_class_based_metrics=['a', 'b']) + + logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) + labels = sparse_tensor.SparseTensor( + values=['a', 'a', 'b'], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + labels_onehot = np.array([[1, 0], [1, 1]], dtype=np.int64) + # loss = labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits)) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels_onehot, logits=logits)) + + keys = metric_keys.MetricKeys + expected_metrics = { + # Average loss over examples. + keys.LOSS_MEAN: expected_loss, + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. 
+ keys.AUC: 0.3333, + keys.AUC_PR: 0.7639, + keys.PROBABILITY_MEAN_AT_CLASS % 0: np.sum(_sigmoid(logits[:, 0])) / 2., + keys.AUC_AT_CLASS % 0: 0., + keys.AUC_PR_AT_CLASS % 0: 1., + keys.PROBABILITY_MEAN_AT_CLASS % 1: np.sum(_sigmoid(logits[:, 1])) / 2., + keys.AUC_AT_CLASS % 1: 1., + keys.AUC_PR_AT_CLASS % 1: 1., + } + + self._test_eval( + head=head, + logits=logits, + labels=labels, + expected_loss=expected_loss, + expected_metrics=expected_metrics) + def test_eval_with_weights(self): n_classes = 2 head = head_lib.multi_label_head(n_classes, weight_column='example_weights') @@ -899,6 +990,34 @@ class MultiLabelHead(test.TestCase): six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) + def test_train_with_update_ops(self): + head = head_lib.multi_label_head(n_classes=2) + + with ops.Graph().as_default(): + w = variables.Variable(1) + update_op = w.assign_add(1) + ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op) + + t = variables.Variable('') + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + del loss + return t.assign(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), + labels=np.array([[1, 0], [1, 1]], dtype=np.int64), + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + sess.run(spec.train_op) + w_value, t_value = sess.run([w, t]) + self.assertEqual(2, w_value) + self.assertEqual(expected_train_result, t_value) + def test_train_with_regularization_losses(self): head = head_lib.multi_label_head( n_classes=2, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..ddd6aa442f82bad2d4714dbcdc85b20b34773068 
--- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/hooks.py @@ -0,0 +1,213 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Some useful session run hooks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import training + + +# pylint: disable=protected-access +class InMemoryEvaluatorHook(training.SessionRunHook): + """Hook to run evaluation in training without a checkpoint. + + Example: + + ```python + def train_input_fn(): + ... + return train_dataset + + def eval_input_fn(): + ... + return eval_dataset + + estimator = tf.estimator.DNNClassifier(...) + + evaluator = tf.contrib.estimator.InMemoryEvaluatorHook( + estimator, eval_input_fn) + estimator.train(train_input_fn, hooks=[evaluator]) + ``` + + Current limitations of this approach are: + * It doesn't support multi-node distributed mode. 
+ * It doesn't support saveable objects other than variables (such as boosted + tree support) + * It doesn't support custom saver logic (such as ExponentialMovingAverage + support) + + """ + + def __init__(self, + estimator, + input_fn, + steps=None, + hooks=None, + name=None, + every_n_iter=100): + """Initializes an `InMemoryEvaluatorHook`. + + Args: + estimator: A `tf.estimator.Estimator` instance to call evaluate. + input_fn: Equivalent to the `input_fn` arg to `estimator.evaluate`. A + function that constructs the input data for evaluation. + See @{$premade_estimators#create_input_functions} for more + information. The function should construct and return one of + the following: + + * A `tf.data.Dataset` object: Outputs of `Dataset` object must be a + tuple (features, labels) with same constraints as below. + * A tuple (features, labels): Where `features` is a `Tensor` or a + dictionary of string feature name to `Tensor` and `labels` is a + `Tensor` or a dictionary of string label name to `Tensor`. Both + `features` and `labels` are consumed by `model_fn`. They should + satisfy the expectation of `model_fn` from inputs. + + steps: Equivalent to the `steps` arg to `estimator.evaluate`. Number of + steps for which to evaluate model. If `None`, evaluates until `input_fn` + raises an end-of-input exception. + hooks: Equivalent to the `hooks` arg to `estimator.evaluate`. List of + `SessionRunHook` subclass instances. Used for callbacks inside the + evaluation call. + name: Equivalent to the `name` arg to `estimator.evaluate`. Name of the + evaluation if user needs to run multiple evaluations on different data + sets, such as on training data vs test data. Metrics for different + evaluations are saved in separate folders, and appear separately in + TensorBoard. + every_n_iter: `int`, runs the evaluator once every N training iterations.
+ + Raises: + ValueError: if `every_n_iter` is non-positive or training is not on a + single machine. + """ + if every_n_iter is None or every_n_iter <= 0: + raise ValueError('invalid every_n_iter=%s.' % every_n_iter) + if (estimator.config.num_ps_replicas > 0 or + estimator.config.num_worker_replicas > 1): + raise ValueError( + 'InMemoryEvaluator supports only single machine (aka Local) setting.') + self._estimator = estimator + self._input_fn = input_fn + self._steps = steps + self._name = name + self._every_n_iter = every_n_iter + self._eval_dir = os.path.join(self._estimator.model_dir, 'eval' + if not name else 'eval_' + name) + + self._graph = None + self._hooks = estimator_lib._check_hooks_type(hooks) + self._hooks.extend(self._estimator._convert_eval_steps_to_hooks(steps)) + self._timer = training.SecondOrStepTimer(every_steps=every_n_iter) + + def begin(self): + """Builds the eval graph and the restoring op.""" + self._timer.reset() + self._iter_count = 0 + self._graph = ops.Graph() + with self._graph.as_default(): + (self._scaffold, self._update_op, self._eval_dict, + self._all_hooks) = self._estimator._evaluate_build_graph( + self._input_fn, self._hooks, checkpoint_path=None) + + if self._scaffold.saver is not None: + raise ValueError('InMemoryEvaluator does not support custom saver') + if self._scaffold.init_fn is not None: + raise ValueError('InMemoryEvaluator does not support custom init_fn') + + self._var_name_to_eval_var = { + v.name: v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + } + self._var_name_to_placeholder = { + v.name: array_ops.placeholder(v.dtype) + for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + } + + def after_create_session(self, session, coord): # pylint: disable=unused-argument + """Does first run which shows the eval metrics before training.""" + if ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS): + raise ValueError( + 'InMemoryEvaluator does not support saveables other than global ' + 'variables.') +
 self._var_name_to_train_var = { + v.name: v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + } + var_names_to_transfer = set(self._var_name_to_placeholder.keys()) & set( + self._var_name_to_train_var.keys()) + # Filter out training var names that do not exist in evaluation + self._var_name_to_train_var = { + v_name: self._var_name_to_train_var[v_name] + for v_name in var_names_to_transfer + } + # Filter out eval var names that do not exist in training + self._var_name_to_eval_var = { + v_name: self._var_name_to_eval_var[v_name] + for v_name in var_names_to_transfer + } + + with self._graph.as_default(): + self._var_feed_op = control_flow_ops.group([ + state_ops.assign(self._var_name_to_eval_var[v_name], + self._var_name_to_placeholder[v_name]) + for v_name in var_names_to_transfer + ]) + + self._evaluate(session) + + def _evaluate(self, train_session): + var_name_to_value = train_session.run(self._var_name_to_train_var) + placeholder_to_value = { + self._var_name_to_placeholder[v_name]: var_name_to_value[v_name] + for v_name in var_name_to_value + } + + def feed_variables(scaffold, session): + del scaffold + session.run(self._var_feed_op, feed_dict=placeholder_to_value) + + scaffold = training.Scaffold( + init_fn=feed_variables, copy_from_scaffold=self._scaffold) + + with self._graph.as_default(): + self._estimator._evaluate_run( + checkpoint_path=None, + scaffold=scaffold, + update_op=self._update_op, + eval_dict=self._eval_dict, + all_hooks=self._all_hooks, + output_dir=self._eval_dir) + + self._timer.update_last_triggered_step(self._iter_count) + + def after_run(self, run_context, run_values): # pylint: disable=unused-argument + """Runs evaluator.""" + self._iter_count += 1 + if self._timer.should_trigger_for_step(self._iter_count): + self._evaluate(run_context.session) + + def end(self, session): # pylint: disable=unused-argument + """Runs evaluator for final model.""" + self._evaluate(session) + + +# pylint: enable=protected-access diff --git
a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py new file mode 100644 index 0000000000000000000000000000000000000000..95ae971852ee6dffb6174fc243686721c30ef685 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py @@ -0,0 +1,318 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for hooks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import json +import os + +from tensorflow.contrib.estimator.python.estimator import hooks as hooks_lib +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import estimator_lib +from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from 
tensorflow.python.platform import test +from tensorflow.python.summary import summary_iterator +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import training + + +def summary_step_keyword_to_value_mapping(dir_): + writer_cache.FileWriterCache.clear() + + # Get last Event written. + event_paths = glob.glob(os.path.join(dir_, 'events*')) + step_keyword_to_value = {} + for last_event in summary_iterator.summary_iterator(event_paths[-1]): + if last_event.step not in step_keyword_to_value: + step_keyword_to_value[last_event.step] = {} + if last_event.summary is not None: + for value in last_event.summary.value: + step_keyword_to_value[last_event.step][value.tag] = value.simple_value + + return step_keyword_to_value + + +def get_summary_value(dir_, step, keyword): + """Get summary value for given step and keyword.""" + + writer_cache.FileWriterCache.clear() + # Get last Event written. + event_paths = glob.glob(os.path.join(dir_, 'events*')) + print('XXX', event_paths) + for last_event in summary_iterator.summary_iterator(event_paths[-1]): + if last_event.step == step and last_event.summary is not None: + for value in last_event.summary.value: + if keyword in value.tag: + return value.simple_value + return None + + +class InMemoryEvaluatorHookTest(test.TestCase): + + def test_runs_eval_metrics(self): + + def model_fn(features, labels, mode): + _ = labels + if estimator_lib.ModeKeys.TRAIN == mode: + with ops.control_dependencies([features]): + train_op = state_ops.assign_add(training.get_global_step(), 1) + return estimator_lib.EstimatorSpec( + mode, loss=constant_op.constant(3.), train_op=train_op) + if estimator_lib.ModeKeys.EVAL == mode: + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(5.), + eval_metric_ops={'mean_of_features': metrics_lib.mean(features)}) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = 
hooks_lib.InMemoryEvaluatorHook( + estimator, input_fn, every_n_iter=4) + estimator.train(input_fn, hooks=[evaluator]) + + self.assertTrue(os.path.isdir(estimator.eval_dir())) + step_keyword_to_value = summary_step_keyword_to_value_mapping( + estimator.eval_dir()) + # 4.5 = sum(range(10))/10 + # before training + self.assertEqual(4.5, step_keyword_to_value[0]['mean_of_features']) + # intervals (every_n_iter=4) + self.assertEqual(4.5, step_keyword_to_value[4]['mean_of_features']) + self.assertEqual(4.5, step_keyword_to_value[8]['mean_of_features']) + # end + self.assertEqual(4.5, step_keyword_to_value[10]['mean_of_features']) + + def test_uses_latest_variable_value(self): + + def model_fn(features, labels, mode): + _ = labels + step = training.get_global_step() + w = variable_scope.get_variable( + 'w', + shape=[], + initializer=init_ops.zeros_initializer(), + dtype=dtypes.int64) + if estimator_lib.ModeKeys.TRAIN == mode: + # to consume features, we have control dependency + with ops.control_dependencies([features]): + step_inc = state_ops.assign_add(training.get_global_step(), 1) + with ops.control_dependencies([step_inc]): + assign_w_to_step_plus_2 = w.assign(step + 2) + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + train_op=assign_w_to_step_plus_2) + if estimator_lib.ModeKeys.EVAL == mode: + # to consume features, we have control dependency + with ops.control_dependencies([features]): + loss = constant_op.constant(5.) + return estimator_lib.EstimatorSpec( + mode, + loss=loss, + # w is constant in each step, so the mean. 
+ # w = 0 if step==0 else step+2 + eval_metric_ops={'mean_of_const': metrics_lib.mean(w)}) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook( + estimator, input_fn, every_n_iter=4) + estimator.train(input_fn, hooks=[evaluator]) + + self.assertTrue(os.path.isdir(estimator.eval_dir())) + step_keyword_to_value = summary_step_keyword_to_value_mapping( + estimator.eval_dir()) + # w = 0 if step==0 else step+2 + self.assertEqual(0, step_keyword_to_value[0]['mean_of_const']) + self.assertEqual(6, step_keyword_to_value[4]['mean_of_const']) + self.assertEqual(12, step_keyword_to_value[10]['mean_of_const']) + + def test_dnn_classifier(self): + embedding = feature_column_lib.embedding_column( + feature_column_lib.categorical_column_with_vocabulary_list( + 'wire_cast', ['kima', 'omar', 'stringer']), 8) + dnn = estimator_lib.DNNClassifier( + feature_columns=[embedding], hidden_units=[3, 1]) + + def train_input_fn(): + return dataset_ops.Dataset.from_tensors(({ + 'wire_cast': [['omar'], ['kima']] + }, [[0], [1]])).repeat(3) + + def eval_input_fn(): + return dataset_ops.Dataset.from_tensors(({ + 'wire_cast': [['stringer'], ['kima']] + }, [[0], [1]])).repeat(2) + + evaluator = hooks_lib.InMemoryEvaluatorHook( + dnn, eval_input_fn, name='in-memory') + dnn.train(train_input_fn, hooks=[evaluator]) + self.assertTrue(os.path.isdir(dnn.eval_dir('in-memory'))) + step_keyword_to_value = summary_step_keyword_to_value_mapping( + dnn.eval_dir('in-memory')) + + final_metrics = dnn.evaluate(eval_input_fn) + step = final_metrics[ops.GraphKeys.GLOBAL_STEP] + for summary_tag in final_metrics: + if summary_tag == ops.GraphKeys.GLOBAL_STEP: + continue + self.assertEqual(final_metrics[summary_tag], + step_keyword_to_value[step][summary_tag]) + + def test_raise_error_with_multi_worker(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + 
run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.CHIEF, + 'index': 0 + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + dnn = estimator_lib.DNNClassifier( + feature_columns=[feature_column_lib.numeric_column('x')], + hidden_units=[3, 1]) + + def eval_input_fn(): + pass + + with self.assertRaisesRegexp(ValueError, 'supports only single machine'): + hooks_lib.InMemoryEvaluatorHook(dnn, eval_input_fn) + + def test_raise_error_with_ps(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1'], + }, + 'task': { + 'type': run_config_lib.TaskType.CHIEF, + 'index': 0 + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + dnn = estimator_lib.DNNClassifier( + feature_columns=[feature_column_lib.numeric_column('x')], + hidden_units=[3, 1]) + + def eval_input_fn(): + pass + + with self.assertRaisesRegexp(ValueError, 'supports only single machine'): + hooks_lib.InMemoryEvaluatorHook(dnn, eval_input_fn) + + def test_raise_error_with_custom_saver_in_eval(self): + + def model_fn(features, labels, mode): + _, _ = features, labels + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + scaffold=training.Scaffold(saver=training.Saver()), + train_op=constant_op.constant(5.), + eval_metric_ops={ + 'mean_of_features': metrics_lib.mean(constant_op.constant(2.)) + }) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn) + with self.assertRaisesRegexp(ValueError, 'does not support custom saver'): + evaluator.begin() + + def test_raise_error_with_custom_init_fn_in_eval(self): + + def model_fn(features, labels, mode): + _, _ = features, labels + + def init_fn(scaffold, session): + _, _ = scaffold, session + + return 
estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + scaffold=training.Scaffold(init_fn=init_fn), + train_op=constant_op.constant(5.), + eval_metric_ops={ + 'mean_of_features': metrics_lib.mean(constant_op.constant(2.)) + }) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn) + with self.assertRaisesRegexp(ValueError, 'does not support custom init_fn'): + evaluator.begin() + + def test_raise_error_with_saveables_other_than_global_variables(self): + + def model_fn(features, labels, mode): + _, _ = features, labels + w = variables.Variable( + initial_value=[0.], + trainable=False, + collections=[ops.GraphKeys.SAVEABLE_OBJECTS]) + init_op = control_flow_ops.group( + [w.initializer, training.get_global_step().initializer]) + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + scaffold=training.Scaffold(init_op=init_op), + train_op=constant_op.constant(5.), + eval_metric_ops={ + 'mean_of_features': metrics_lib.mean(constant_op.constant(2.)) + }) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn) + with self.assertRaisesRegexp(ValueError, 'does not support saveables'): + estimator.train(input_fn, hooks=[evaluator]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/logit_fns.py b/tensorflow/contrib/estimator/python/estimator/logit_fns.py index 09c2862ccd3f90de4153a2095afc9c3d3f9476c1..c8b0dd62970e341a3c6b176278fe1c2adfcd8d20 100644 --- a/tensorflow/contrib/estimator/python/estimator/logit_fns.py +++ b/tensorflow/contrib/estimator/python/estimator/logit_fns.py @@ -41,10 +41,10 @@ from __future__ import print_function import six -from tensorflow.python.estimator import util from 
tensorflow.python.estimator.canned import dnn as dnn_core from tensorflow.python.estimator.canned import linear as linear_core from tensorflow.python.framework import ops +from tensorflow.python.util import function_utils # pylint: disable=protected-access dnn_logit_fn_builder = dnn_core._dnn_logit_fn_builder @@ -72,7 +72,7 @@ def call_logit_fn(logit_fn, features, mode, params, config): ValueError: if logit_fn does not return a Tensor or a dictionary mapping strings to Tensors. """ - logit_fn_args = util.fn_args(logit_fn) + logit_fn_args = function_utils.fn_args(logit_fn) kwargs = {} if 'mode' in logit_fn_args: kwargs['mode'] = mode diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py index f8564446e5da3e785b85010998d18dca0424d16b..cda23aa437f954700b74dcb9294550eb9a8a8c5c 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -32,7 +32,6 @@ import six from tensorflow.core.framework import node_def_pb2 from tensorflow.python.client import device_lib from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator import util from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import device as framework_device from tensorflow.python.framework import ops as ops_lib @@ -48,6 +47,7 @@ from tensorflow.python.platform import tf_logging from tensorflow.python.training import device_setter as device_setter_lib from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.util import deprecation +from tensorflow.python.util import function_utils @deprecation.deprecated( @@ -521,7 +521,7 @@ def _get_loss_towers(model_fn, """Replicate the loss computation across devices.""" tower_specs = [] - model_fn_args = util.fn_args(model_fn) + model_fn_args = 
function_utils.fn_args(model_fn) optional_params = {} if 'params' in model_fn_args: optional_params['params'] = copy.deepcopy(params) diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py index 7f385fd76e88aba46f45d16198d707bf1d1e0d8a..7c49cd00d16777872ad1211dfa1d1a3ac9ac1cee 100644 --- a/tensorflow/contrib/estimator/python/estimator/rnn.py +++ b/tensorflow/contrib/estimator/python/estimator/rnn.py @@ -229,6 +229,7 @@ def _rnn_logit_fn_builder(output_units, rnn_cell_fn, sequence_feature_columns, rnn_outputs, _ = rnn.dynamic_rnn( cell=cell, inputs=sequence_input, + sequence_length=sequence_length, dtype=dtypes.float32, time_major=False) last_activations = _select_last_activations(rnn_outputs, sequence_length) diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py index 5cef4068ed119d5dbccd585c5b4e5e28840d2cc7..8f73274c2a0ebbdc41ce6a647a8a5650694c9a23 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py @@ -197,7 +197,8 @@ class WALSModel(object): row_weights=1, col_weights=1, use_factors_weights_cache=True, - use_gramian_cache=True): + use_gramian_cache=True, + use_scoped_vars=False): """Creates model for WALS matrix factorization. Args: @@ -239,6 +240,8 @@ class WALSModel(object): weights cache to take effect. use_gramian_cache: When True, the Gramians will be cached on the workers before the updates start. Defaults to True. + use_scoped_vars: When True, the factor and weight vars will also be nested + in a tf.name_scope. 
""" self._input_rows = input_rows self._input_cols = input_cols @@ -251,25 +254,46 @@ class WALSModel(object): regularization * linalg_ops.eye(self._n_components) if regularization is not None else None) assert (row_weights is None) == (col_weights is None) - self._row_weights = WALSModel._create_weights( - row_weights, self._input_rows, self._num_row_shards, "row_weights") - self._col_weights = WALSModel._create_weights( - col_weights, self._input_cols, self._num_col_shards, "col_weights") self._use_factors_weights_cache = use_factors_weights_cache self._use_gramian_cache = use_gramian_cache - self._row_factors = self._create_factors( - self._input_rows, self._n_components, self._num_row_shards, row_init, - "row_factors") - self._col_factors = self._create_factors( - self._input_cols, self._n_components, self._num_col_shards, col_init, - "col_factors") + + if use_scoped_vars: + with ops.name_scope("row_weights"): + self._row_weights = WALSModel._create_weights( + row_weights, self._input_rows, self._num_row_shards, "row_weights") + with ops.name_scope("col_weights"): + self._col_weights = WALSModel._create_weights( + col_weights, self._input_cols, self._num_col_shards, "col_weights") + with ops.name_scope("row_factors"): + self._row_factors = self._create_factors( + self._input_rows, self._n_components, self._num_row_shards, + row_init, "row_factors") + with ops.name_scope("col_factors"): + self._col_factors = self._create_factors( + self._input_cols, self._n_components, self._num_col_shards, + col_init, "col_factors") + else: + self._row_weights = WALSModel._create_weights( + row_weights, self._input_rows, self._num_row_shards, "row_weights") + self._col_weights = WALSModel._create_weights( + col_weights, self._input_cols, self._num_col_shards, "col_weights") + self._row_factors = self._create_factors( + self._input_rows, self._n_components, self._num_row_shards, row_init, + "row_factors") + self._col_factors = self._create_factors( + self._input_cols, 
self._n_components, self._num_col_shards, col_init, + "col_factors") + self._row_gramian = self._create_gramian(self._n_components, "row_gramian") self._col_gramian = self._create_gramian(self._n_components, "col_gramian") - self._row_update_prep_gramian = self._prepare_gramian( - self._col_factors, self._col_gramian) - self._col_update_prep_gramian = self._prepare_gramian( - self._row_factors, self._row_gramian) - self._create_transient_vars() + with ops.name_scope("row_prepare_gramian"): + self._row_update_prep_gramian = self._prepare_gramian( + self._col_factors, self._col_gramian) + with ops.name_scope("col_prepare_gramian"): + self._col_update_prep_gramian = self._prepare_gramian( + self._row_factors, self._row_gramian) + with ops.name_scope("transient_vars"): + self._create_transient_vars() @property def row_factors(self): diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index 555beddeaab419bcb23d06f960d370b706d744c8..b588f75efe9d0bbf8213a89978a627c0a0ccf554 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -346,7 +346,8 @@ def sequence_numeric_column( key, shape=(1,), default_value=0., - dtype=dtypes.float32): + dtype=dtypes.float32, + normalizer_fn=None): """Returns a feature column that represents sequences of numeric data. Example: @@ -370,6 +371,12 @@ def sequence_numeric_column( default_value: A single value compatible with `dtype` that is used for padding the sparse data into a dense `Tensor`. dtype: The type of values. + normalizer_fn: If not `None`, a function that can be used to normalize the + value of the tensor after `default_value` is applied for parsing. + Normalizer function takes the input `Tensor` as its argument, and returns + the output `Tensor`. (e.g. 
lambda x: (x - 3.0) / 4.2). Please note that + even though the most common use case of this function is normalization, it + can be used for any kind of Tensorflow transformations. Returns: A `_SequenceNumericColumn`. @@ -383,12 +390,16 @@ def sequence_numeric_column( if not (dtype.is_integer or dtype.is_floating): raise ValueError('dtype must be convertible to float. ' 'dtype: {}, key: {}'.format(dtype, key)) + if normalizer_fn is not None and not callable(normalizer_fn): + raise TypeError( + 'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn)) return _SequenceNumericColumn( key, shape=shape, default_value=default_value, - dtype=dtype) + dtype=dtype, + normalizer_fn=normalizer_fn) def _assert_all_equal_and_return(tensors, name=None): @@ -407,7 +418,7 @@ class _SequenceNumericColumn( fc._SequenceDenseColumn, collections.namedtuple( '_SequenceNumericColumn', - ['key', 'shape', 'default_value', 'dtype'])): + ['key', 'shape', 'default_value', 'dtype', 'normalizer_fn'])): """Represents sequences of numeric data.""" @property @@ -419,7 +430,10 @@ class _SequenceNumericColumn( return {self.key: parsing_ops.VarLenFeature(self.dtype)} def _transform_feature(self, inputs): - return inputs.get(self.key) + input_tensor = inputs.get(self.key) + if self.normalizer_fn is not None: + input_tensor = self.normalizer_fn(input_tensor) + return input_tensor @property def _variable_shape(self): diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py index 88f5d535162939e063eb1e7f43d495137c5adef4..89b5f4c4137f6c42417f539a578fd8b11f8b235d 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors 
from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test from tensorflow.python.training import monitored_session @@ -670,6 +671,7 @@ class SequenceNumericColumnTest(test.TestCase): self.assertEqual((1,), a.shape) self.assertEqual(0., a.default_value) self.assertEqual(dtypes.float32, a.dtype) + self.assertIsNone(a.normalizer_fn) def test_shape_saved_as_tuple(self): a = sfc.sequence_numeric_column('aaa', shape=[1, 2]) @@ -688,6 +690,10 @@ class SequenceNumericColumnTest(test.TestCase): ValueError, 'dtype must be convertible to float'): sfc.sequence_numeric_column('aaa', dtype=dtypes.string) + def test_normalizer_fn_must_be_callable(self): + with self.assertRaisesRegexp(TypeError, 'must be a callable'): + sfc.sequence_numeric_column('aaa', normalizer_fn='NotACallable') + def test_get_sequence_dense_tensor(self): sparse_input = sparse_tensor.SparseTensorValue( # example 0, values [[0.], [1]] @@ -708,6 +714,41 @@ class SequenceNumericColumnTest(test.TestCase): self.assertAllEqual( expected_dense_tensor, dense_tensor.eval(session=sess)) + def test_get_sequence_dense_tensor_with_normalizer_fn(self): + + def _increment_two(input_sparse_tensor): + return sparse_ops.sparse_add( + input_sparse_tensor, + sparse_tensor.SparseTensor(((0, 0), (1, 1)), (2.0, 2.0), (2, 2)) + ) + + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + + # Before _increment_two: + # [[0.], [1.]], + # [[10.], [0.]], + # After _increment_two: + # [[2.], [1.]], + # [[10.], [2.]], + expected_dense_tensor = [ + [[2.], [1.]], + [[10.], [2.]], + ] + numeric_column = sfc.sequence_numeric_column( + 'aaa', normalizer_fn=_increment_two) + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with 
monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + def test_get_sequence_dense_tensor_with_shape(self): """Tests get_sequence_dense_tensor with shape !=(1,).""" sparse_input = sparse_tensor.SparseTensorValue( diff --git a/tensorflow/contrib/ffmpeg/__init__.py b/tensorflow/contrib/ffmpeg/__init__.py index daba965a98893b992abdc598ec713f13020d6e91..484ffee3e7afe55c63cab2a463454353b2663e18 100644 --- a/tensorflow/contrib/ffmpeg/__init__.py +++ b/tensorflow/contrib/ffmpeg/__init__.py @@ -28,7 +28,6 @@ from __future__ import print_function from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_audio from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.contrib.ffmpeg.ffmpeg_ops import encode_audio -from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py index 020b5c99c61019254bef0b1dff6bc5901c92758a..b1b5126d9e9e5196a1733b80e0778e53cef7f774 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py +++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py @@ -21,7 +21,6 @@ from __future__ import print_function from tensorflow.contrib.ffmpeg.ops import gen_decode_audio_op_py from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py -from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.util import loader from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 10d1ecc738de6777784200ba934a521dff592e28..dc49383c5c300e82839c478e097074b3e8776b3b 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -119,14 +119,13 @@ from 
tensorflow.python.framework.smart_cond import smart_cond from tensorflow.python.framework.smart_cond import smart_constant_value from tensorflow.python.framework.tensor_spec import BoundedTensorSpec from tensorflow.python.framework.tensor_spec import TensorSpec -from tensorflow.python.ops.array_ops import broadcast_to from tensorflow.python.ops.init_ops import convolutional_delta_orthogonal from tensorflow.python.ops.init_ops import convolutional_orthogonal_1d from tensorflow.python.ops.init_ops import convolutional_orthogonal_2d from tensorflow.python.ops.init_ops import convolutional_orthogonal_3d from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['nest', 'broadcast_to'] +_allowed_symbols = ['nest'] _nest_allowed_symbols = [ 'assert_same_structure', 'is_sequence', diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py index 8fc4f60492b0bfb22ea78cb7b5906e452bb6da58..af1b404cb51bf5d8f8350481f2301d9653895e85 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py @@ -78,7 +78,6 @@ class AssertScalarIntTest(test.TestCase): [3, 4], dtype=dtypes.int32)) -@test_util.with_c_api class WithShapeTest(test.TestCase): def _assert_with_shape(self, tensor, expected_value, expected_shape, @@ -216,25 +215,18 @@ class WithShapeTest(test.TestCase): tensor_partial_shape.set_shape([None, 2]) for incompatible_shape in [[0], [1]]: - if ops._USE_C_API: - error_message = "Shapes must be equal rank, but are 2 and 1" - else: - error_message = r"Shapes \(\?, 2\) and \([01],\) are not compatible" self.assertRaisesRegexp( - ValueError, error_message, + ValueError, "Shapes must be equal rank, but are 2 and 1", tensor_util.with_shape, incompatible_shape, tensor_partial_shape) for incompatible_shape in [[1, 2, 1]]: self.assertRaisesRegexp(ValueError, "Dimensions must be equal", 
 tensor_util.with_shape, incompatible_shape, tensor_partial_shape) for incompatible_shape in [[2, 1]]: - if ops._USE_C_API: - error_message = (r"Dimension 1 in both shapes must be equal, but are " - r"2 and 1. Shapes are \[\?,2\] and \[2,1\].") - else: - error_message = r"Shapes \(\?, 2\) and \(2, 1\) are not compatible" self.assertRaisesRegexp( - ValueError, error_message, + ValueError, + r"Dimension 1 in both shapes must be equal, but are 2 and 1. " + r"Shapes are \[\?,2\] and \[2,1\].", tensor_util.with_shape, incompatible_shape, tensor_partial_shape) compatible_shape = [2, 2] diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD index 0eb6889db1fae1c74aeb4392441b308392b091a5..0f0813c07f8bd330b089780064e02f8dfe7d49f6 100644 --- a/tensorflow/contrib/fused_conv/BUILD +++ b/tensorflow/contrib/fused_conv/BUILD @@ -75,6 +75,7 @@ tf_kernel_library( "//tensorflow/core/kernels:gpu_util_hdrs", "//tensorflow/core/kernels:ops_util_hdrs", "//third_party/eigen3", + "@local_config_cuda//cuda:cudnn_header", ], alwayslink = 1, ) @@ -94,6 +95,7 @@ tf_custom_op_library( "//tensorflow/core/kernels:conv_ops_gpu_hdrs", "//tensorflow/core/kernels:gpu_util_hdrs", "//tensorflow/core/kernels:ops_util_hdrs", + "@local_config_cuda//cuda:cudnn_header", ], ) diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py index 3d0ed899322c26bf4ae428930899d7a5885e9f21..4d62ac65ff619f98a18387058fdc8a0eade0d8f8 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py @@ -289,8 +289,8 @@ class FusedConv2DBiasActivationTest(test.TestCase): conv = tensors[i] value = values[i] ref_value = ref_values[i] - print("expected = ", ref_value) - print("actual = ", value) + tf_logging.info("expected = %s", ref_value) + tf_logging.info("actual = %s",
 value) tol = 1e-5 if value.dtype == np.float16: tol = 1e-3 @@ -831,7 +831,8 @@ class FusedConvInt8Tests(test.TestCase): vertical_stride, padding_type) output_width = CalculateConvolvedOutputDim(input_width, filter_width, horizontal_stride, padding_type) - print("output_height=", output_height, ", output_width=", output_width) + tf_logging.info("output_height=%s, output_width=%s", output_height, + output_width) side_input, _, _ = gen_array_ops.quantize_v2( random_ops.random_uniform( @@ -866,8 +867,8 @@ class FusedConvInt8Tests(test.TestCase): with self.test_session(use_gpu=True) as sess: actual_y, expected_y = sess.run([actual, expected]) - print("actual_y = ", actual_y) - print("expected_y = ", expected_y) + tf_logging.info("actual_y = %s", actual_y) + tf_logging.info("expected_y = %s", expected_y) self.assertTrue(np.array_equal(actual_y, expected_y)) def testFusedConvInt8(self): diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py index 2889e937436d2faa66b5693c19046e122cbaf652..9f5fee45422e0b9bcbc73674e55ae395ea8533d5 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -570,7 +570,7 @@ class MutualInformationPenaltyTest(test.TestCase, _PenaltyTest): 'predicted_distributions': self._predicted_distributions, } self._expected_loss = 1.61610 - self._expected_op_name = 'mutual_information_loss/mul' + self._expected_op_name = 'mutual_information_loss/mul_1' self._batch_size = 2 diff --git a/tensorflow/contrib/gdr/gdr_server_lib.cc b/tensorflow/contrib/gdr/gdr_server_lib.cc index 1f9dd0decb84cf9b7b703f18c061d3c0c7a1cb25..9025c992a4467f521d6d8d514e6a5e92f5492947 100644 --- a/tensorflow/contrib/gdr/gdr_server_lib.cc +++ b/tensorflow/contrib/gdr/gdr_server_lib.cc @@ -57,7 +57,7 @@ Status GdrServer::Init() { new GdrWorker(env, remote_memory_manager_.get())); }; TF_RETURN_IF_ERROR( -
GrpcServer::Init(nullptr, rendezvous_mgr_func, worker_func)); + GrpcServer::Init(nullptr, rendezvous_mgr_func, nullptr, worker_func)); return remote_memory_manager_->Init(); } diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 592d37b432ee605d74162e0b8ec6ccdf426c45d1..026a3d1200033400472c4fd763a244c04b284a9b 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -189,9 +189,6 @@ def copy_op_handler(info, op, new_inputs, copy_shape=True, nodedef_fn=None): if op._original_op: op_._original_op = op._original_op - # Add op to the graph - info.graph_._add_op(op_) - return op_, op_.outputs @@ -492,7 +489,7 @@ class Transformer(object): t_ = info.transformed_ts[t] consumer_op_ = info.transformed_ops[consumer_op] t_index_ = list(consumer_op_.inputs).index(tmp_t_) - consumer_op_._update_input(t_index_, t_, update_dtype=False) # pylint: disable=protected-access + consumer_op_._update_input(t_index_, t_) # pylint: disable=protected-access def _connect_control_inputs(self, info): """Connect the previously copied ops.""" diff --git a/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c b/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c index 6a5d982dc8514d69277b8f042ac1256e28715d9e..2e5c84704f8464ab46d740ea3c1eef0548826e8d 100644 --- a/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c +++ b/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c @@ -19,7 +19,7 @@ limitations under the License. 
#include "hexagon_controller.h" -#include +#include #include #include "adspmsgd.h" diff --git a/tensorflow/contrib/integrate/python/ops/odes.py b/tensorflow/contrib/integrate/python/ops/odes.py index b4a99867ed46897f60be3f230838c3f576d5455e..61f78febfc07bb4e677259366a81c16b2b585244 100644 --- a/tensorflow/contrib/integrate/python/ops/odes.py +++ b/tensorflow/contrib/integrate/python/ops/odes.py @@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import tensor_array_ops @@ -279,13 +278,27 @@ def _assert_increasing(t): return ops.control_dependencies([assert_increasing]) -def _check_input_types(t, y0): +def _check_input_types(y0, t, dt=None): if not (y0.dtype.is_floating or y0.dtype.is_complex): raise TypeError('`y0` must have a floating point or complex floating ' 'point dtype') if not t.dtype.is_floating: raise TypeError('`t` must have a floating point dtype') + if dt is not None and not dt.dtype.is_floating: + raise TypeError('`dt` must have a floating point dtype') + + +def _check_input_sizes(t, dt): + if len(t.get_shape().as_list()) > 1: + raise ValueError('t must be a 1D tensor') + + if len(dt.get_shape().as_list()) > 1: + raise ValueError('dt must be a 1D tensor') + + if t.get_shape()[0] != dt.get_shape()[0] + 1: + raise ValueError('t and dt have incompatible lengths, must be N and N-1') + def _dopri5(func, y0, @@ -510,7 +523,7 @@ def odeint(func, # avoiding the need to pack/unpack in user functions. 
y0 = ops.convert_to_tensor(y0, name='y0') t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t') - _check_input_types(t, y0) + _check_input_types(y0, t) error_dtype = abs(y0).dtype rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol') @@ -530,24 +543,74 @@ def odeint(func, class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)): """Base class for fixed-grid ODE integrators.""" - def integrate(self, evol_func, y0, time_grid): - time_delta_grid = time_grid[1:] - time_grid[:-1] - - scan_func = self._make_scan_func(evol_func) + def integrate(self, evol_func, y0, time_grid, dt_grid, steps_on_intervals): + """Returns integrated values of a differential equation on `time_grid`. + + Numerically integrates the differential equation defined via the time + derivative evaluator `evol_func` using fixed time steps specified in `dt_grid`. + + Args: + evol_func: Callable, evaluates time derivative of y at a given time. + y0: N-D Tensor holding the initial values of the solution. + time_grid: 1-D Tensor holding the time points at which the solution + will be recorded, must have a floating dtype. + dt_grid: 1-D Tensor holding the fixed time steps to be used on `time_grid` + intervals. Must have a floating dtype and one fewer element than + `time_grid`. + steps_on_intervals: 1-D Tensor of integer dtype, must have the same size + as `dt_grid`. Specifies number of steps needed for every interval. Assumes + steps_on_intervals * dt_grid == time intervals. + + Returns: + (N+1)-D tensor, where the first dimension corresponds to different + time points. Contains the solved value of y for each desired time point in + `time_grid`, with the initial value `y0` being the first element along the first + dimension. 
+ """ - y_grid = functional_ops.scan(scan_func, (time_grid[:-1], time_delta_grid), - y0) - return array_ops.concat([[y0], y_grid], axis=0) + iteration_func = self._make_iteration_func(evol_func, dt_grid) + integrate_interval = self._make_interval_integrator(iteration_func, + steps_on_intervals) - def _make_scan_func(self, evol_func): + num_times = array_ops.size(time_grid) + current_time = time_grid[0] + solution_array = tensor_array_ops.TensorArray(y0.dtype, num_times) + solution_array = solution_array.write(0, y0) - def scan_func(y, t_and_dt): - t, dt = t_and_dt + solution_array, _, _, _ = control_flow_ops.while_loop( + lambda _, __, ___, i: i < num_times, + integrate_interval, + (solution_array, y0, current_time, 1) + ) + solution_array = solution_array.stack() + solution_array.set_shape(time_grid.get_shape().concatenate(y0.get_shape())) + return solution_array + + def _make_iteration_func(self, evol_func, dt_grid): + """Returns a function that builds operations of a single time step.""" + + def iteration_func(y, t, dt_step, interval_step): + """Performs a single time step advance.""" + dt = dt_grid[interval_step - 1] dy = self._step_func(evol_func, t, dt, y) dy = math_ops.cast(dy, dtype=y.dtype) - return y + dy + return y + dy, t + dt, dt_step + 1, interval_step + + return iteration_func + + def _make_interval_integrator(self, iteration_func, interval_sizes): + """Returns a function that builds operations for interval integration.""" - return scan_func + def integrate_interval(solution_array, y, t, interval_num): + """Integrates y with fixed time step on interval `interval_num`.""" + y, t, _, _ = control_flow_ops.while_loop( + lambda _, __, j, interval_num: j < interval_sizes[interval_num - 1], + iteration_func, + (y, t, 0, interval_num) + ) + return solution_array.write(interval_num, y), y, t, interval_num + 1 + + return integrate_interval @abc.abstractmethod def _step_func(self, evol_func, t, dt, y): @@ -555,6 +618,7 @@ class 
_FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)): class _MidpointFixedGridIntegrator(_FixedGridIntegrator): + """Fixed grid integrator implementing midpoint scheme.""" def _step_func(self, evol_func, t, dt, y): dt_cast = math_ops.cast(dt, y.dtype) @@ -563,6 +627,7 @@ class _MidpointFixedGridIntegrator(_FixedGridIntegrator): class _RK4FixedGridIntegrator(_FixedGridIntegrator): + """Fixed grid integrator implementing RK4 scheme.""" def _step_func(self, evol_func, t, dt, y): k1 = evol_func(y, t) @@ -575,7 +640,7 @@ class _RK4FixedGridIntegrator(_FixedGridIntegrator): return math_ops.add_n([k1, 2 * k2, 2 * k3, k4]) * (dt_cast / 6) -def odeint_fixed(func, y0, t, method='rk4', name=None): +def odeint_fixed(func, y0, t, dt=None, method='rk4', name=None): """ODE integration on a fixed grid (with no step size control). Useful in certain scenarios to avoid the overhead of adaptive step size @@ -590,6 +655,14 @@ def odeint_fixed(func, y0, t, method='rk4', name=None): `y`. The initial time point should be the first element of this sequence, and each time must be larger than the previous time. May have any floating point dtype. + dt: 0-D or 1-D Tensor providing time step suggestion to be used on time + integration intervals in `t`. 1-D Tensor should provide values + for all intervals, must have 1 less element than that of `t`. + If given a 0-D Tensor, the value is interpreted as time step suggestion + same for all intervals. If passed None, then time step is set to be the + t[1:] - t[:-1]. Defaults to None. The actual step size is obtained by + ensuring an integer number of steps per interval, potentially reducing the + time step. method: One of 'midpoint' or 'rk4'. name: Optional name for the resulting operation. @@ -602,16 +675,29 @@ def odeint_fixed(func, y0, t, method='rk4', name=None): Raises: ValueError: Upon caller errors. 
""" - with ops.name_scope(name, 'odeint_fixed', [y0, t]): + with ops.name_scope(name, 'odeint_fixed', [y0, t, dt]): t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t') y0 = ops.convert_to_tensor(y0, name='y0') - _check_input_types(t, y0) + + intervals = t[1:] - t[:-1] + if dt is None: + dt = intervals + dt = ops.convert_to_tensor(dt, preferred_dtype=dtypes.float64, name='dt') + + steps_on_intervals = math_ops.ceil(intervals / dt) + dt = intervals / steps_on_intervals + steps_on_intervals = math_ops.cast(steps_on_intervals, dtype=dtypes.int32) + + _check_input_types(y0, t, dt) + _check_input_sizes(t, dt) with _assert_increasing(t): with ops.name_scope(method): if method == 'midpoint': - return _MidpointFixedGridIntegrator().integrate(func, y0, t) + return _MidpointFixedGridIntegrator().integrate(func, y0, t, dt, + steps_on_intervals) elif method == 'rk4': - return _RK4FixedGridIntegrator().integrate(func, y0, t) + return _RK4FixedGridIntegrator().integrate(func, y0, t, dt, + steps_on_intervals) else: raise ValueError('method not supported: {!s}'.format(method)) diff --git a/tensorflow/contrib/integrate/python/ops/odes_test.py b/tensorflow/contrib/integrate/python/ops/odes_test.py index 3ec01212d25ca8dc6e13f340177a5e85138868d5..c7b4e2faa84e1a87cb1904b22eb0008ab1ee4be6 100644 --- a/tensorflow/contrib/integrate/python/ops/odes_test.py +++ b/tensorflow/contrib/integrate/python/ops/odes_test.py @@ -242,40 +242,56 @@ class InterpolationTest(test.TestCase): class OdeIntFixedTest(test.TestCase): - def _test_integrate_sine(self, method): + def _test_integrate_sine(self, method, t, dt=None): def evol_func(y, t): del t return array_ops.stack([y[1], -y[0]]) y0 = [0., 1.] 
- time_grid = np.linspace(0., 10., 200) - y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method) + y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method) with self.test_session() as sess: y_grid_array = sess.run(y_grid) np.testing.assert_allclose( - y_grid_array[:, 0], np.sin(time_grid), rtol=1e-2, atol=1e-2) + y_grid_array[:, 0], np.sin(t), rtol=1e-2, atol=1e-2) - def _test_integrate_gaussian(self, method): + def _test_integrate_gaussian(self, method, t, dt=None): def evol_func(y, t): return -math_ops.cast(t, dtype=y.dtype) * y[0] y0 = [1.] - time_grid = np.linspace(0., 2., 100) - y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method) + y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method) with self.test_session() as sess: y_grid_array = sess.run(y_grid) np.testing.assert_allclose( - y_grid_array[:, 0], np.exp(-time_grid**2 / 2), rtol=1e-2, atol=1e-2) + y_grid_array[:, 0], np.exp(-t**2 / 2), rtol=1e-2, atol=1e-2) + + def _test_integrate_sine_all(self, method): + uniform_time_grid = np.linspace(0., 10., 200) + non_uniform_time_grid = np.asarray([0.0, 0.4, 4.7, 5.2, 7.0]) + uniform_dt = 0.02 + non_uniform_dt = np.asarray([0.01, 0.001, 0.05, 0.03]) + self._test_integrate_sine(method, uniform_time_grid) + self._test_integrate_sine(method, non_uniform_time_grid, uniform_dt) + self._test_integrate_sine(method, non_uniform_time_grid, non_uniform_dt) + + def _test_integrate_gaussian_all(self, method): + uniform_time_grid = np.linspace(0., 2., 100) + non_uniform_time_grid = np.asarray([0.0, 0.1, 0.7, 1.2, 2.0]) + uniform_dt = 0.01 + non_uniform_dt = np.asarray([0.01, 0.001, 0.1, 0.03]) + self._test_integrate_gaussian(method, uniform_time_grid) + self._test_integrate_gaussian(method, non_uniform_time_grid, uniform_dt) + self._test_integrate_gaussian(method, non_uniform_time_grid, non_uniform_dt) def _test_everything(self, method): - self._test_integrate_sine(method) - self._test_integrate_gaussian(method) + 
self._test_integrate_sine_all(method) + self._test_integrate_gaussian_all(method) def test_midpoint(self): self._test_everything('midpoint') @@ -283,6 +299,21 @@ class OdeIntFixedTest(test.TestCase): def test_rk4(self): self._test_everything('rk4') + def test_dt_size_exceptions(self): + times = np.linspace(0., 2., 100) + dt = np.ones(99) * 0.01 + dt_wrong_length = np.asarray([0.01, 0.001, 0.1, 0.03]) + dt_wrong_dim = np.expand_dims(np.linspace(0., 2., 99), axis=0) + times_wrong_dim = np.expand_dims(np.linspace(0., 2., 100), axis=0) + with self.assertRaises(ValueError): + self._test_integrate_gaussian('midpoint', times, dt_wrong_length) + + with self.assertRaises(ValueError): + self._test_integrate_gaussian('midpoint', times, dt_wrong_dim) + + with self.assertRaises(ValueError): + self._test_integrate_gaussian('midpoint', times_wrong_dim, dt) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc index a4cd4a2cc4b99b5906185bd2b942ed15c1ddf5e4..2638b25ec424b5b4ef556ff769e94e64da32fec2 100644 --- a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc +++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc @@ -64,7 +64,7 @@ class KafkaDatasetOp : public DatasetOpKernel { eof_(eof), timeout_(timeout) {} - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Kafka")})); @@ -81,7 +81,7 @@ class KafkaDatasetOp : public DatasetOpKernel { return *shapes; } - string DebugString() override { return "KafkaDatasetOp::Dataset"; } + string DebugString() const override { return "KafkaDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(DatasetGraphDefBuilder* b, diff --git a/tensorflow/contrib/keras/api/keras/activations/__init__.py b/tensorflow/contrib/keras/api/keras/activations/__init__.py index 
d04838c218d6643a703723a1d163c88547c14da7..3f0184276f6b903be63f7b35459e4ad57044eb2c 100644 --- a/tensorflow/contrib/keras/api/keras/activations/__init__.py +++ b/tensorflow/contrib/keras/api/keras/activations/__init__.py @@ -19,22 +19,22 @@ from __future__ import division from __future__ import print_function # Activation functions. -from tensorflow.python.keras._impl.keras.activations import elu -from tensorflow.python.keras._impl.keras.activations import hard_sigmoid -from tensorflow.python.keras._impl.keras.activations import linear -from tensorflow.python.keras._impl.keras.activations import relu -from tensorflow.python.keras._impl.keras.activations import selu -from tensorflow.python.keras._impl.keras.activations import sigmoid -from tensorflow.python.keras._impl.keras.activations import softmax -from tensorflow.python.keras._impl.keras.activations import softplus -from tensorflow.python.keras._impl.keras.activations import softsign -from tensorflow.python.keras._impl.keras.activations import tanh +from tensorflow.python.keras.activations import elu +from tensorflow.python.keras.activations import hard_sigmoid +from tensorflow.python.keras.activations import linear +from tensorflow.python.keras.activations import relu +from tensorflow.python.keras.activations import selu +from tensorflow.python.keras.activations import sigmoid +from tensorflow.python.keras.activations import softmax +from tensorflow.python.keras.activations import softplus +from tensorflow.python.keras.activations import softsign +from tensorflow.python.keras.activations import tanh # Auxiliary utils. 
# pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.activations import deserialize -from tensorflow.python.keras._impl.keras.activations import serialize -from tensorflow.python.keras._impl.keras.activations import get +from tensorflow.python.keras.activations import deserialize +from tensorflow.python.keras.activations import serialize +from tensorflow.python.keras.activations import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py b/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py index abf8393ae45d71dc0cb746706abb72f77b82d199..6dfb5cab17c088bfab8ed806adeabd793ced4d12 100644 --- a/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.inception_v3 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.inception_v3 import InceptionV3 -from tensorflow.python.keras._impl.keras.applications.inception_v3 import preprocess_input +from tensorflow.python.keras.applications.inception_v3 import decode_predictions +from tensorflow.python.keras.applications.inception_v3 import InceptionV3 +from tensorflow.python.keras.applications.inception_v3 import preprocess_input del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py b/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py index b809e91193b459a46906443796344c092e1d2a6b..67306cc51e1927cfbc2db424b1f4165dabfa22f9 100644 --- a/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import 
from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.mobilenet import decode_predictions -from tensorflow.python.keras._impl.keras.applications.mobilenet import MobileNet -from tensorflow.python.keras._impl.keras.applications.mobilenet import preprocess_input +from tensorflow.python.keras.applications.mobilenet import decode_predictions +from tensorflow.python.keras.applications.mobilenet import MobileNet +from tensorflow.python.keras.applications.mobilenet import preprocess_input del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py b/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py index 530805d150bfe32c5b81d7d7d3f92e203b83b602..a25ff48b593a9a9ea56fd427a932bb64c10f7b7b 100644 --- a/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.resnet50 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.resnet50 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.resnet50 import ResNet50 +from tensorflow.python.keras.applications.resnet50 import decode_predictions +from tensorflow.python.keras.applications.resnet50 import preprocess_input +from tensorflow.python.keras.applications.resnet50 import ResNet50 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py b/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py index 118361604bbc7e0a88ed34243c0d5ea98856a301..4964b1b7deb56fe0025e9a8d8cb45d18e0209fea 100644 --- a/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py +++ 
b/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.vgg16 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.vgg16 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.vgg16 import VGG16 +from tensorflow.python.keras.applications.vgg16 import decode_predictions +from tensorflow.python.keras.applications.vgg16 import preprocess_input +from tensorflow.python.keras.applications.vgg16 import VGG16 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py b/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py index cda52628f3c10d65fdbe70b2f86cc12c771870a9..afb3abebdd6735e6f17bc94c1fcd15a31b74f983 100644 --- a/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.vgg19 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.vgg19 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.vgg19 import VGG19 +from tensorflow.python.keras.applications.vgg19 import decode_predictions +from tensorflow.python.keras.applications.vgg19 import preprocess_input +from tensorflow.python.keras.applications.vgg19 import VGG19 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py b/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py index ae9cd9cd18c5ccc5ec37c8cd1bf36f8aabd9929c..2e3335d02aff0fff805fc2dac614b14e0593d40d 100644 --- 
a/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.xception import decode_predictions -from tensorflow.python.keras._impl.keras.applications.xception import preprocess_input -from tensorflow.python.keras._impl.keras.applications.xception import Xception +from tensorflow.python.keras.applications.xception import decode_predictions +from tensorflow.python.keras.applications.xception import preprocess_input +from tensorflow.python.keras.applications.xception import Xception del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/backend/__init__.py b/tensorflow/contrib/keras/api/keras/backend/__init__.py index 10ef5a75852deb6595bced2703d7c5f29b0efac3..a755364014206e92289eec0b9c8e510251862e0e 100644 --- a/tensorflow/contrib/keras/api/keras/backend/__init__.py +++ b/tensorflow/contrib/keras/api/keras/backend/__init__.py @@ -19,144 +19,144 @@ from __future__ import division from __future__ import print_function # pylint: disable=redefined-builtin -from tensorflow.python.keras._impl.keras.backend import abs -from tensorflow.python.keras._impl.keras.backend import all -from tensorflow.python.keras._impl.keras.backend import any -from tensorflow.python.keras._impl.keras.backend import arange -from tensorflow.python.keras._impl.keras.backend import argmax -from tensorflow.python.keras._impl.keras.backend import argmin -from tensorflow.python.keras._impl.keras.backend import backend -from tensorflow.python.keras._impl.keras.backend import batch_dot -from tensorflow.python.keras._impl.keras.backend import batch_flatten -from tensorflow.python.keras._impl.keras.backend import batch_get_value -from tensorflow.python.keras._impl.keras.backend import batch_normalization -from 
tensorflow.python.keras._impl.keras.backend import batch_set_value -from tensorflow.python.keras._impl.keras.backend import bias_add -from tensorflow.python.keras._impl.keras.backend import binary_crossentropy -from tensorflow.python.keras._impl.keras.backend import cast -from tensorflow.python.keras._impl.keras.backend import cast_to_floatx -from tensorflow.python.keras._impl.keras.backend import categorical_crossentropy -from tensorflow.python.keras._impl.keras.backend import clear_session -from tensorflow.python.keras._impl.keras.backend import clip -from tensorflow.python.keras._impl.keras.backend import concatenate -from tensorflow.python.keras._impl.keras.backend import constant -from tensorflow.python.keras._impl.keras.backend import conv1d -from tensorflow.python.keras._impl.keras.backend import conv2d -from tensorflow.python.keras._impl.keras.backend import conv2d_transpose -from tensorflow.python.keras._impl.keras.backend import conv3d -from tensorflow.python.keras._impl.keras.backend import cos -from tensorflow.python.keras._impl.keras.backend import count_params -from tensorflow.python.keras._impl.keras.backend import ctc_batch_cost -from tensorflow.python.keras._impl.keras.backend import ctc_decode -from tensorflow.python.keras._impl.keras.backend import ctc_label_dense_to_sparse -from tensorflow.python.keras._impl.keras.backend import dot -from tensorflow.python.keras._impl.keras.backend import dropout -from tensorflow.python.keras._impl.keras.backend import dtype -from tensorflow.python.keras._impl.keras.backend import elu -from tensorflow.python.keras._impl.keras.backend import epsilon -from tensorflow.python.keras._impl.keras.backend import equal -from tensorflow.python.keras._impl.keras.backend import eval -from tensorflow.python.keras._impl.keras.backend import exp -from tensorflow.python.keras._impl.keras.backend import expand_dims -from tensorflow.python.keras._impl.keras.backend import eye -from tensorflow.python.keras._impl.keras.backend 
import flatten -from tensorflow.python.keras._impl.keras.backend import floatx -from tensorflow.python.keras._impl.keras.backend import foldl -from tensorflow.python.keras._impl.keras.backend import foldr -from tensorflow.python.keras._impl.keras.backend import function -from tensorflow.python.keras._impl.keras.backend import gather -from tensorflow.python.keras._impl.keras.backend import get_session -from tensorflow.python.keras._impl.keras.backend import get_uid -from tensorflow.python.keras._impl.keras.backend import get_value -from tensorflow.python.keras._impl.keras.backend import gradients -from tensorflow.python.keras._impl.keras.backend import greater -from tensorflow.python.keras._impl.keras.backend import greater_equal -from tensorflow.python.keras._impl.keras.backend import hard_sigmoid -from tensorflow.python.keras._impl.keras.backend import image_data_format -from tensorflow.python.keras._impl.keras.backend import in_test_phase -from tensorflow.python.keras._impl.keras.backend import in_top_k -from tensorflow.python.keras._impl.keras.backend import in_train_phase -from tensorflow.python.keras._impl.keras.backend import int_shape -from tensorflow.python.keras._impl.keras.backend import is_sparse -from tensorflow.python.keras._impl.keras.backend import l2_normalize -from tensorflow.python.keras._impl.keras.backend import learning_phase -from tensorflow.python.keras._impl.keras.backend import less -from tensorflow.python.keras._impl.keras.backend import less_equal -from tensorflow.python.keras._impl.keras.backend import log -from tensorflow.python.keras._impl.keras.backend import manual_variable_initialization -from tensorflow.python.keras._impl.keras.backend import map_fn -from tensorflow.python.keras._impl.keras.backend import max -from tensorflow.python.keras._impl.keras.backend import maximum -from tensorflow.python.keras._impl.keras.backend import mean -from tensorflow.python.keras._impl.keras.backend import min -from 
tensorflow.python.keras._impl.keras.backend import minimum -from tensorflow.python.keras._impl.keras.backend import moving_average_update -from tensorflow.python.keras._impl.keras.backend import name_scope -from tensorflow.python.keras._impl.keras.backend import ndim -from tensorflow.python.keras._impl.keras.backend import normalize_batch_in_training -from tensorflow.python.keras._impl.keras.backend import not_equal -from tensorflow.python.keras._impl.keras.backend import one_hot -from tensorflow.python.keras._impl.keras.backend import ones -from tensorflow.python.keras._impl.keras.backend import ones_like -from tensorflow.python.keras._impl.keras.backend import permute_dimensions -from tensorflow.python.keras._impl.keras.backend import placeholder -from tensorflow.python.keras._impl.keras.backend import pool2d -from tensorflow.python.keras._impl.keras.backend import pool3d -from tensorflow.python.keras._impl.keras.backend import pow -from tensorflow.python.keras._impl.keras.backend import print_tensor -from tensorflow.python.keras._impl.keras.backend import prod -from tensorflow.python.keras._impl.keras.backend import random_binomial -from tensorflow.python.keras._impl.keras.backend import random_normal -from tensorflow.python.keras._impl.keras.backend import random_normal_variable -from tensorflow.python.keras._impl.keras.backend import random_uniform -from tensorflow.python.keras._impl.keras.backend import random_uniform_variable -from tensorflow.python.keras._impl.keras.backend import relu -from tensorflow.python.keras._impl.keras.backend import repeat -from tensorflow.python.keras._impl.keras.backend import repeat_elements -from tensorflow.python.keras._impl.keras.backend import reset_uids -from tensorflow.python.keras._impl.keras.backend import reshape -from tensorflow.python.keras._impl.keras.backend import resize_images -from tensorflow.python.keras._impl.keras.backend import resize_volumes -from tensorflow.python.keras._impl.keras.backend import reverse 
-from tensorflow.python.keras._impl.keras.backend import rnn -from tensorflow.python.keras._impl.keras.backend import round -from tensorflow.python.keras._impl.keras.backend import separable_conv2d -from tensorflow.python.keras._impl.keras.backend import set_epsilon -from tensorflow.python.keras._impl.keras.backend import set_floatx -from tensorflow.python.keras._impl.keras.backend import set_image_data_format -from tensorflow.python.keras._impl.keras.backend import set_learning_phase -from tensorflow.python.keras._impl.keras.backend import set_session -from tensorflow.python.keras._impl.keras.backend import set_value -from tensorflow.python.keras._impl.keras.backend import shape -from tensorflow.python.keras._impl.keras.backend import sigmoid -from tensorflow.python.keras._impl.keras.backend import sign -from tensorflow.python.keras._impl.keras.backend import sin -from tensorflow.python.keras._impl.keras.backend import softmax -from tensorflow.python.keras._impl.keras.backend import softplus -from tensorflow.python.keras._impl.keras.backend import softsign -from tensorflow.python.keras._impl.keras.backend import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.backend import spatial_2d_padding -from tensorflow.python.keras._impl.keras.backend import spatial_3d_padding -from tensorflow.python.keras._impl.keras.backend import sqrt -from tensorflow.python.keras._impl.keras.backend import square -from tensorflow.python.keras._impl.keras.backend import squeeze -from tensorflow.python.keras._impl.keras.backend import stack -from tensorflow.python.keras._impl.keras.backend import std -from tensorflow.python.keras._impl.keras.backend import stop_gradient -from tensorflow.python.keras._impl.keras.backend import sum -from tensorflow.python.keras._impl.keras.backend import switch -from tensorflow.python.keras._impl.keras.backend import tanh -from tensorflow.python.keras._impl.keras.backend import temporal_padding -from 
tensorflow.python.keras._impl.keras.backend import to_dense -from tensorflow.python.keras._impl.keras.backend import transpose -from tensorflow.python.keras._impl.keras.backend import truncated_normal -from tensorflow.python.keras._impl.keras.backend import update -from tensorflow.python.keras._impl.keras.backend import update_add -from tensorflow.python.keras._impl.keras.backend import update_sub -from tensorflow.python.keras._impl.keras.backend import var -from tensorflow.python.keras._impl.keras.backend import variable -from tensorflow.python.keras._impl.keras.backend import zeros -from tensorflow.python.keras._impl.keras.backend import zeros_like +from tensorflow.python.keras.backend import abs +from tensorflow.python.keras.backend import all +from tensorflow.python.keras.backend import any +from tensorflow.python.keras.backend import arange +from tensorflow.python.keras.backend import argmax +from tensorflow.python.keras.backend import argmin +from tensorflow.python.keras.backend import backend +from tensorflow.python.keras.backend import batch_dot +from tensorflow.python.keras.backend import batch_flatten +from tensorflow.python.keras.backend import batch_get_value +from tensorflow.python.keras.backend import batch_normalization +from tensorflow.python.keras.backend import batch_set_value +from tensorflow.python.keras.backend import bias_add +from tensorflow.python.keras.backend import binary_crossentropy +from tensorflow.python.keras.backend import cast +from tensorflow.python.keras.backend import cast_to_floatx +from tensorflow.python.keras.backend import categorical_crossentropy +from tensorflow.python.keras.backend import clear_session +from tensorflow.python.keras.backend import clip +from tensorflow.python.keras.backend import concatenate +from tensorflow.python.keras.backend import constant +from tensorflow.python.keras.backend import conv1d +from tensorflow.python.keras.backend import conv2d +from tensorflow.python.keras.backend import 
conv2d_transpose +from tensorflow.python.keras.backend import conv3d +from tensorflow.python.keras.backend import cos +from tensorflow.python.keras.backend import count_params +from tensorflow.python.keras.backend import ctc_batch_cost +from tensorflow.python.keras.backend import ctc_decode +from tensorflow.python.keras.backend import ctc_label_dense_to_sparse +from tensorflow.python.keras.backend import dot +from tensorflow.python.keras.backend import dropout +from tensorflow.python.keras.backend import dtype +from tensorflow.python.keras.backend import elu +from tensorflow.python.keras.backend import epsilon +from tensorflow.python.keras.backend import equal +from tensorflow.python.keras.backend import eval +from tensorflow.python.keras.backend import exp +from tensorflow.python.keras.backend import expand_dims +from tensorflow.python.keras.backend import eye +from tensorflow.python.keras.backend import flatten +from tensorflow.python.keras.backend import floatx +from tensorflow.python.keras.backend import foldl +from tensorflow.python.keras.backend import foldr +from tensorflow.python.keras.backend import function +from tensorflow.python.keras.backend import gather +from tensorflow.python.keras.backend import get_session +from tensorflow.python.keras.backend import get_uid +from tensorflow.python.keras.backend import get_value +from tensorflow.python.keras.backend import gradients +from tensorflow.python.keras.backend import greater +from tensorflow.python.keras.backend import greater_equal +from tensorflow.python.keras.backend import hard_sigmoid +from tensorflow.python.keras.backend import image_data_format +from tensorflow.python.keras.backend import in_test_phase +from tensorflow.python.keras.backend import in_top_k +from tensorflow.python.keras.backend import in_train_phase +from tensorflow.python.keras.backend import int_shape +from tensorflow.python.keras.backend import is_sparse +from tensorflow.python.keras.backend import l2_normalize +from 
tensorflow.python.keras.backend import learning_phase +from tensorflow.python.keras.backend import less +from tensorflow.python.keras.backend import less_equal +from tensorflow.python.keras.backend import log +from tensorflow.python.keras.backend import manual_variable_initialization +from tensorflow.python.keras.backend import map_fn +from tensorflow.python.keras.backend import max +from tensorflow.python.keras.backend import maximum +from tensorflow.python.keras.backend import mean +from tensorflow.python.keras.backend import min +from tensorflow.python.keras.backend import minimum +from tensorflow.python.keras.backend import moving_average_update +from tensorflow.python.keras.backend import name_scope +from tensorflow.python.keras.backend import ndim +from tensorflow.python.keras.backend import normalize_batch_in_training +from tensorflow.python.keras.backend import not_equal +from tensorflow.python.keras.backend import one_hot +from tensorflow.python.keras.backend import ones +from tensorflow.python.keras.backend import ones_like +from tensorflow.python.keras.backend import permute_dimensions +from tensorflow.python.keras.backend import placeholder +from tensorflow.python.keras.backend import pool2d +from tensorflow.python.keras.backend import pool3d +from tensorflow.python.keras.backend import pow +from tensorflow.python.keras.backend import print_tensor +from tensorflow.python.keras.backend import prod +from tensorflow.python.keras.backend import random_binomial +from tensorflow.python.keras.backend import random_normal +from tensorflow.python.keras.backend import random_normal_variable +from tensorflow.python.keras.backend import random_uniform +from tensorflow.python.keras.backend import random_uniform_variable +from tensorflow.python.keras.backend import relu +from tensorflow.python.keras.backend import repeat +from tensorflow.python.keras.backend import repeat_elements +from tensorflow.python.keras.backend import reset_uids +from 
tensorflow.python.keras.backend import reshape +from tensorflow.python.keras.backend import resize_images +from tensorflow.python.keras.backend import resize_volumes +from tensorflow.python.keras.backend import reverse +from tensorflow.python.keras.backend import rnn +from tensorflow.python.keras.backend import round +from tensorflow.python.keras.backend import separable_conv2d +from tensorflow.python.keras.backend import set_epsilon +from tensorflow.python.keras.backend import set_floatx +from tensorflow.python.keras.backend import set_image_data_format +from tensorflow.python.keras.backend import set_learning_phase +from tensorflow.python.keras.backend import set_session +from tensorflow.python.keras.backend import set_value +from tensorflow.python.keras.backend import shape +from tensorflow.python.keras.backend import sigmoid +from tensorflow.python.keras.backend import sign +from tensorflow.python.keras.backend import sin +from tensorflow.python.keras.backend import softmax +from tensorflow.python.keras.backend import softplus +from tensorflow.python.keras.backend import softsign +from tensorflow.python.keras.backend import sparse_categorical_crossentropy +from tensorflow.python.keras.backend import spatial_2d_padding +from tensorflow.python.keras.backend import spatial_3d_padding +from tensorflow.python.keras.backend import sqrt +from tensorflow.python.keras.backend import square +from tensorflow.python.keras.backend import squeeze +from tensorflow.python.keras.backend import stack +from tensorflow.python.keras.backend import std +from tensorflow.python.keras.backend import stop_gradient +from tensorflow.python.keras.backend import sum +from tensorflow.python.keras.backend import switch +from tensorflow.python.keras.backend import tanh +from tensorflow.python.keras.backend import temporal_padding +from tensorflow.python.keras.backend import to_dense +from tensorflow.python.keras.backend import transpose +from tensorflow.python.keras.backend import 
truncated_normal +from tensorflow.python.keras.backend import update +from tensorflow.python.keras.backend import update_add +from tensorflow.python.keras.backend import update_sub +from tensorflow.python.keras.backend import var +from tensorflow.python.keras.backend import variable +from tensorflow.python.keras.backend import zeros +from tensorflow.python.keras.backend import zeros_like del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/callbacks/__init__.py b/tensorflow/contrib/keras/api/keras/callbacks/__init__.py index 2d884790ddb9ccf49649c6af4cfd40cddbc38cb3..10e05f2969bc404d4cf3a9b7a999510cd40e3c17 100644 --- a/tensorflow/contrib/keras/api/keras/callbacks/__init__.py +++ b/tensorflow/contrib/keras/api/keras/callbacks/__init__.py @@ -18,19 +18,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.callbacks import BaseLogger -from tensorflow.python.keras._impl.keras.callbacks import Callback -from tensorflow.python.keras._impl.keras.callbacks import CSVLogger -from tensorflow.python.keras._impl.keras.callbacks import EarlyStopping -from tensorflow.python.keras._impl.keras.callbacks import History -from tensorflow.python.keras._impl.keras.callbacks import LambdaCallback -from tensorflow.python.keras._impl.keras.callbacks import LearningRateScheduler -from tensorflow.python.keras._impl.keras.callbacks import ModelCheckpoint -from tensorflow.python.keras._impl.keras.callbacks import ProgbarLogger -from tensorflow.python.keras._impl.keras.callbacks import ReduceLROnPlateau -from tensorflow.python.keras._impl.keras.callbacks import RemoteMonitor -from tensorflow.python.keras._impl.keras.callbacks import TensorBoard -from tensorflow.python.keras._impl.keras.callbacks import TerminateOnNaN +from tensorflow.python.keras.callbacks import BaseLogger +from tensorflow.python.keras.callbacks import Callback +from 
tensorflow.python.keras.callbacks import CSVLogger +from tensorflow.python.keras.callbacks import EarlyStopping +from tensorflow.python.keras.callbacks import History +from tensorflow.python.keras.callbacks import LambdaCallback +from tensorflow.python.keras.callbacks import LearningRateScheduler +from tensorflow.python.keras.callbacks import ModelCheckpoint +from tensorflow.python.keras.callbacks import ProgbarLogger +from tensorflow.python.keras.callbacks import ReduceLROnPlateau +from tensorflow.python.keras.callbacks import RemoteMonitor +from tensorflow.python.keras.callbacks import TensorBoard +from tensorflow.python.keras.callbacks import TerminateOnNaN del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/constraints/__init__.py b/tensorflow/contrib/keras/api/keras/constraints/__init__.py index 152606d8ebbcadf57d971d508e15283da65e4aa3..08debf974ec3a36174c353ecaf9e425a9afc3f36 100644 --- a/tensorflow/contrib/keras/api/keras/constraints/__init__.py +++ b/tensorflow/contrib/keras/api/keras/constraints/__init__.py @@ -19,21 +19,21 @@ from __future__ import division from __future__ import print_function # Constraints functions / callable classes. 
-from tensorflow.python.keras._impl.keras.constraints import Constraint
-from tensorflow.python.keras._impl.keras.constraints import max_norm
-from tensorflow.python.keras._impl.keras.constraints import MaxNorm
-from tensorflow.python.keras._impl.keras.constraints import min_max_norm
-from tensorflow.python.keras._impl.keras.constraints import MinMaxNorm
-from tensorflow.python.keras._impl.keras.constraints import non_neg
-from tensorflow.python.keras._impl.keras.constraints import NonNeg
-from tensorflow.python.keras._impl.keras.constraints import unit_norm
-from tensorflow.python.keras._impl.keras.constraints import UnitNorm
+from tensorflow.python.keras.constraints import Constraint
+from tensorflow.python.keras.constraints import max_norm
+from tensorflow.python.keras.constraints import MaxNorm
+from tensorflow.python.keras.constraints import min_max_norm
+from tensorflow.python.keras.constraints import MinMaxNorm
+from tensorflow.python.keras.constraints import non_neg
+from tensorflow.python.keras.constraints import NonNeg
+from tensorflow.python.keras.constraints import unit_norm
+from tensorflow.python.keras.constraints import UnitNorm

 # Auxiliary utils.
 # pylint: disable=g-bad-import-order
-from tensorflow.python.keras._impl.keras.constraints import deserialize
-from tensorflow.python.keras._impl.keras.constraints import serialize
-from tensorflow.python.keras._impl.keras.constraints import get
+from tensorflow.python.keras.constraints import deserialize
+from tensorflow.python.keras.constraints import serialize
+from tensorflow.python.keras.constraints import get

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py
index b5371a03fd5f5755ba8844415276113c565f52db..a5a6fdab445d2d5328f203b6a704f89e9bb4ce67 100644
--- a/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.keras._impl.keras.datasets.boston_housing import load_data
+from tensorflow.python.keras.datasets.boston_housing import load_data

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py
index 68d3eb789ea2c410095c0c75e0b79a9b07d209a3..e74e5f347df2eeb626cd781c54c9a7b76561d4e9 100644
--- a/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.keras._impl.keras.datasets.cifar10 import load_data
+from tensorflow.python.keras.datasets.cifar10 import load_data

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py
index ca93742673341660ba69712feb59c5dd32ea3252..8f5753a6360dfbddb5678c4f2c02adff86b5f0cb 100644
--- a/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.keras._impl.keras.datasets.cifar100 import load_data
+from tensorflow.python.keras.datasets.cifar100 import load_data

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py
index 1c6396d2d32b88eaa900a5af4e62c7484fceab63..bd6ec4b8dfb0344ad0b89956939607ef51bb0889 100644
--- a/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.keras._impl.keras.datasets.imdb import get_word_index
-from tensorflow.python.keras._impl.keras.datasets.imdb import load_data
+from tensorflow.python.keras.datasets.imdb import get_word_index
+from tensorflow.python.keras.datasets.imdb import load_data

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py
index 364255f3387b59a419c010db9b93cdfbcba36186..f61145655bd5d98965e15fecd387d538e9bc642b 100644
--- a/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.keras._impl.keras.datasets.mnist import load_data
+from tensorflow.python.keras.datasets.mnist import load_data

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py
index bb6791a344ad0c372ac60cd4a332f5632841dd46..ade31f4ea9c33204a4350e6bc3a5a2469e54fd61 100644
--- a/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.keras._impl.keras.datasets.reuters import get_word_index
-from tensorflow.python.keras._impl.keras.datasets.reuters import load_data
+from tensorflow.python.keras.datasets.reuters import get_word_index
+from tensorflow.python.keras.datasets.reuters import load_data

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/initializers/__init__.py b/tensorflow/contrib/keras/api/keras/initializers/__init__.py
index 6b1fcfd2d9585d19ae3fd9705e128b19b1ec40e7..c6bdc4f0dac3f446238dc4cbc72fe4be278a5ff6 100644
--- a/tensorflow/contrib/keras/api/keras/initializers/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/initializers/__init__.py
@@ -19,30 +19,30 @@ from __future__ import division
 from __future__ import print_function

 # Initializer functions / callable classes.
-from tensorflow.python.keras._impl.keras.initializers import Constant
-from tensorflow.python.keras._impl.keras.initializers import Identity
-from tensorflow.python.keras._impl.keras.initializers import Initializer
-from tensorflow.python.keras._impl.keras.initializers import Ones
-from tensorflow.python.keras._impl.keras.initializers import Orthogonal
-from tensorflow.python.keras._impl.keras.initializers import RandomNormal
-from tensorflow.python.keras._impl.keras.initializers import RandomUniform
-from tensorflow.python.keras._impl.keras.initializers import TruncatedNormal
-from tensorflow.python.keras._impl.keras.initializers import VarianceScaling
-from tensorflow.python.keras._impl.keras.initializers import Zeros
+from tensorflow.python.keras.initializers import Constant
+from tensorflow.python.keras.initializers import Identity
+from tensorflow.python.keras.initializers import Initializer
+from tensorflow.python.keras.initializers import Ones
+from tensorflow.python.keras.initializers import Orthogonal
+from tensorflow.python.keras.initializers import RandomNormal
+from tensorflow.python.keras.initializers import RandomUniform
+from tensorflow.python.keras.initializers import TruncatedNormal
+from tensorflow.python.keras.initializers import VarianceScaling
+from tensorflow.python.keras.initializers import Zeros

 # Functional interface.
 # pylint: disable=g-bad-import-order
-from tensorflow.python.keras._impl.keras.initializers import glorot_normal
-from tensorflow.python.keras._impl.keras.initializers import glorot_uniform
-from tensorflow.python.keras._impl.keras.initializers import he_normal
-from tensorflow.python.keras._impl.keras.initializers import he_uniform
-from tensorflow.python.keras._impl.keras.initializers import lecun_normal
-from tensorflow.python.keras._impl.keras.initializers import lecun_uniform
+from tensorflow.python.keras.initializers import glorot_normal
+from tensorflow.python.keras.initializers import glorot_uniform
+from tensorflow.python.keras.initializers import he_normal
+from tensorflow.python.keras.initializers import he_uniform
+from tensorflow.python.keras.initializers import lecun_normal
+from tensorflow.python.keras.initializers import lecun_uniform

 # Auxiliary utils.
-from tensorflow.python.keras._impl.keras.initializers import deserialize
-from tensorflow.python.keras._impl.keras.initializers import serialize
-from tensorflow.python.keras._impl.keras.initializers import get
+from tensorflow.python.keras.initializers import deserialize
+from tensorflow.python.keras.initializers import serialize
+from tensorflow.python.keras.initializers import get

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/layers/__init__.py b/tensorflow/contrib/keras/api/keras/layers/__init__.py
index acf0a5e1799b7c57dfd82861c9ccc1f132c34375..938c881fcbe18623fa18c21c112375f9914f887b 100644
--- a/tensorflow/contrib/keras/api/keras/layers/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/layers/__init__.py
@@ -20,128 +20,128 @@ from __future__ import print_function

 # Generic layers.
# pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.engine import Input -from tensorflow.python.keras._impl.keras.engine import InputLayer -from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine import Layer +from tensorflow.python.keras.engine import Input +from tensorflow.python.keras.engine import InputLayer +from tensorflow.python.keras.engine import InputSpec +from tensorflow.python.keras.engine import Layer # Advanced activations. -from tensorflow.python.keras._impl.keras.layers.advanced_activations import LeakyReLU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import PReLU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import ELU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import ThresholdedReLU +from tensorflow.python.keras.layers.advanced_activations import LeakyReLU +from tensorflow.python.keras.layers.advanced_activations import PReLU +from tensorflow.python.keras.layers.advanced_activations import ELU +from tensorflow.python.keras.layers.advanced_activations import ThresholdedReLU # Convolution layers. 
-from tensorflow.python.keras._impl.keras.layers.convolutional import Conv1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConv2D +from tensorflow.python.keras.layers.convolutional import Conv1D +from tensorflow.python.keras.layers.convolutional import Conv2D +from tensorflow.python.keras.layers.convolutional import Conv3D +from tensorflow.python.keras.layers.convolutional import Conv2DTranspose +from tensorflow.python.keras.layers.convolutional import Conv3DTranspose +from tensorflow.python.keras.layers.convolutional import SeparableConv2D # Convolution layer aliases. -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution2D +from tensorflow.python.keras.layers.convolutional import Convolution1D +from tensorflow.python.keras.layers.convolutional import Convolution2D +from tensorflow.python.keras.layers.convolutional import Convolution3D +from tensorflow.python.keras.layers.convolutional import Convolution2DTranspose +from tensorflow.python.keras.layers.convolutional import Convolution3DTranspose +from tensorflow.python.keras.layers.convolutional import SeparableConvolution2D # Image processing layers. 
-from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling1D -from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling2D -from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling3D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding1D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding2D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping3D +from tensorflow.python.keras.layers.convolutional import UpSampling1D +from tensorflow.python.keras.layers.convolutional import UpSampling2D +from tensorflow.python.keras.layers.convolutional import UpSampling3D +from tensorflow.python.keras.layers.convolutional import ZeroPadding1D +from tensorflow.python.keras.layers.convolutional import ZeroPadding2D +from tensorflow.python.keras.layers.convolutional import ZeroPadding3D +from tensorflow.python.keras.layers.convolutional import Cropping1D +from tensorflow.python.keras.layers.convolutional import Cropping2D +from tensorflow.python.keras.layers.convolutional import Cropping3D # Convolutional-recurrent layers. -from tensorflow.python.keras._impl.keras.layers.convolutional_recurrent import ConvLSTM2D +from tensorflow.python.keras.layers.convolutional_recurrent import ConvLSTM2D # Core layers. 
-from tensorflow.python.keras._impl.keras.layers.core import Masking -from tensorflow.python.keras._impl.keras.layers.core import Dropout -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout1D -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout2D -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout3D -from tensorflow.python.keras._impl.keras.layers.core import Activation -from tensorflow.python.keras._impl.keras.layers.core import Reshape -from tensorflow.python.keras._impl.keras.layers.core import Permute -from tensorflow.python.keras._impl.keras.layers.core import Flatten -from tensorflow.python.keras._impl.keras.layers.core import RepeatVector -from tensorflow.python.keras._impl.keras.layers.core import Lambda -from tensorflow.python.keras._impl.keras.layers.core import Dense -from tensorflow.python.keras._impl.keras.layers.core import ActivityRegularization +from tensorflow.python.keras.layers.core import Masking +from tensorflow.python.keras.layers.core import Dropout +from tensorflow.python.keras.layers.core import SpatialDropout1D +from tensorflow.python.keras.layers.core import SpatialDropout2D +from tensorflow.python.keras.layers.core import SpatialDropout3D +from tensorflow.python.keras.layers.core import Activation +from tensorflow.python.keras.layers.core import Reshape +from tensorflow.python.keras.layers.core import Permute +from tensorflow.python.keras.layers.core import Flatten +from tensorflow.python.keras.layers.core import RepeatVector +from tensorflow.python.keras.layers.core import Lambda +from tensorflow.python.keras.layers.core import Dense +from tensorflow.python.keras.layers.core import ActivityRegularization # Embedding layers. -from tensorflow.python.keras._impl.keras.layers.embeddings import Embedding +from tensorflow.python.keras.layers.embeddings import Embedding # Locally-connected layers. 
-from tensorflow.python.keras._impl.keras.layers.local import LocallyConnected1D -from tensorflow.python.keras._impl.keras.layers.local import LocallyConnected2D +from tensorflow.python.keras.layers.local import LocallyConnected1D +from tensorflow.python.keras.layers.local import LocallyConnected2D # Merge layers. -from tensorflow.python.keras._impl.keras.layers.merge import Add -from tensorflow.python.keras._impl.keras.layers.merge import Multiply -from tensorflow.python.keras._impl.keras.layers.merge import Average -from tensorflow.python.keras._impl.keras.layers.merge import Maximum -from tensorflow.python.keras._impl.keras.layers.merge import Concatenate -from tensorflow.python.keras._impl.keras.layers.merge import Dot -from tensorflow.python.keras._impl.keras.layers.merge import add -from tensorflow.python.keras._impl.keras.layers.merge import multiply -from tensorflow.python.keras._impl.keras.layers.merge import average -from tensorflow.python.keras._impl.keras.layers.merge import maximum -from tensorflow.python.keras._impl.keras.layers.merge import concatenate -from tensorflow.python.keras._impl.keras.layers.merge import dot +from tensorflow.python.keras.layers.merge import Add +from tensorflow.python.keras.layers.merge import Multiply +from tensorflow.python.keras.layers.merge import Average +from tensorflow.python.keras.layers.merge import Maximum +from tensorflow.python.keras.layers.merge import Concatenate +from tensorflow.python.keras.layers.merge import Dot +from tensorflow.python.keras.layers.merge import add +from tensorflow.python.keras.layers.merge import multiply +from tensorflow.python.keras.layers.merge import average +from tensorflow.python.keras.layers.merge import maximum +from tensorflow.python.keras.layers.merge import concatenate +from tensorflow.python.keras.layers.merge import dot # Noise layers. 
-from tensorflow.python.keras._impl.keras.layers.noise import AlphaDropout -from tensorflow.python.keras._impl.keras.layers.noise import GaussianNoise -from tensorflow.python.keras._impl.keras.layers.noise import GaussianDropout +from tensorflow.python.keras.layers.noise import AlphaDropout +from tensorflow.python.keras.layers.noise import GaussianNoise +from tensorflow.python.keras.layers.noise import GaussianDropout # Normalization layers. -from tensorflow.python.keras._impl.keras.layers.normalization import BatchNormalization +from tensorflow.python.keras.layers.normalization import BatchNormalization # Pooling layers. -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling3D +from tensorflow.python.keras.layers.pooling import MaxPooling1D +from tensorflow.python.keras.layers.pooling import MaxPooling2D +from tensorflow.python.keras.layers.pooling import MaxPooling3D +from tensorflow.python.keras.layers.pooling import AveragePooling1D +from tensorflow.python.keras.layers.pooling import AveragePooling2D +from tensorflow.python.keras.layers.pooling import 
AveragePooling3D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling1D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling2D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling3D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling1D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling2D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling3D # Pooling layer aliases. -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool3D +from tensorflow.python.keras.layers.pooling import MaxPool1D +from tensorflow.python.keras.layers.pooling import MaxPool2D +from tensorflow.python.keras.layers.pooling import MaxPool3D +from tensorflow.python.keras.layers.pooling import AvgPool1D +from tensorflow.python.keras.layers.pooling import AvgPool2D +from tensorflow.python.keras.layers.pooling import AvgPool3D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool1D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool2D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool3D +from 
tensorflow.python.keras.layers.pooling import GlobalMaxPool1D +from tensorflow.python.keras.layers.pooling import GlobalMaxPool2D +from tensorflow.python.keras.layers.pooling import GlobalMaxPool3D # Recurrent layers. -from tensorflow.python.keras._impl.keras.layers.recurrent import SimpleRNN -from tensorflow.python.keras._impl.keras.layers.recurrent import GRU -from tensorflow.python.keras._impl.keras.layers.recurrent import LSTM +from tensorflow.python.keras.layers.recurrent import SimpleRNN +from tensorflow.python.keras.layers.recurrent import GRU +from tensorflow.python.keras.layers.recurrent import LSTM # Wrapper functions -from tensorflow.python.keras._impl.keras.layers.wrappers import Wrapper -from tensorflow.python.keras._impl.keras.layers.wrappers import Bidirectional -from tensorflow.python.keras._impl.keras.layers.wrappers import TimeDistributed +from tensorflow.python.keras.layers.wrappers import Wrapper +from tensorflow.python.keras.layers.wrappers import Bidirectional +from tensorflow.python.keras.layers.wrappers import TimeDistributed del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/losses/__init__.py b/tensorflow/contrib/keras/api/keras/losses/__init__.py index 66721b694f5fd5fae7ca521ff56d4c6c6bce79b5..c4476a7bbd5056fa898468a46031bf3d8b1e44cf 100644 --- a/tensorflow/contrib/keras/api/keras/losses/__init__.py +++ b/tensorflow/contrib/keras/api/keras/losses/__init__.py @@ -19,26 +19,26 @@ from __future__ import division from __future__ import print_function # Loss functions. 
-from tensorflow.python.keras._impl.keras.losses import binary_crossentropy -from tensorflow.python.keras._impl.keras.losses import categorical_crossentropy -from tensorflow.python.keras._impl.keras.losses import categorical_hinge -from tensorflow.python.keras._impl.keras.losses import cosine_proximity -from tensorflow.python.keras._impl.keras.losses import hinge -from tensorflow.python.keras._impl.keras.losses import kullback_leibler_divergence -from tensorflow.python.keras._impl.keras.losses import logcosh -from tensorflow.python.keras._impl.keras.losses import mean_absolute_error -from tensorflow.python.keras._impl.keras.losses import mean_absolute_percentage_error -from tensorflow.python.keras._impl.keras.losses import mean_squared_error -from tensorflow.python.keras._impl.keras.losses import mean_squared_logarithmic_error -from tensorflow.python.keras._impl.keras.losses import poisson -from tensorflow.python.keras._impl.keras.losses import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.losses import squared_hinge +from tensorflow.python.keras.losses import binary_crossentropy +from tensorflow.python.keras.losses import categorical_crossentropy +from tensorflow.python.keras.losses import categorical_hinge +from tensorflow.python.keras.losses import cosine_proximity +from tensorflow.python.keras.losses import hinge +from tensorflow.python.keras.losses import kullback_leibler_divergence +from tensorflow.python.keras.losses import logcosh +from tensorflow.python.keras.losses import mean_absolute_error +from tensorflow.python.keras.losses import mean_absolute_percentage_error +from tensorflow.python.keras.losses import mean_squared_error +from tensorflow.python.keras.losses import mean_squared_logarithmic_error +from tensorflow.python.keras.losses import poisson +from tensorflow.python.keras.losses import sparse_categorical_crossentropy +from tensorflow.python.keras.losses import squared_hinge # Auxiliary utils. 
# pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.losses import deserialize -from tensorflow.python.keras._impl.keras.losses import serialize -from tensorflow.python.keras._impl.keras.losses import get +from tensorflow.python.keras.losses import deserialize +from tensorflow.python.keras.losses import serialize +from tensorflow.python.keras.losses import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/metrics/__init__.py b/tensorflow/contrib/keras/api/keras/metrics/__init__.py index 59faf037bce0f087d244a2faaeb52713bdc3b772..7317fdb52c5b79e787a49d71be49f5261d6b1fff 100644 --- a/tensorflow/contrib/keras/api/keras/metrics/__init__.py +++ b/tensorflow/contrib/keras/api/keras/metrics/__init__.py @@ -19,28 +19,28 @@ from __future__ import division from __future__ import print_function # Metrics functions. -from tensorflow.python.keras._impl.keras.metrics import binary_accuracy -from tensorflow.python.keras._impl.keras.metrics import binary_crossentropy -from tensorflow.python.keras._impl.keras.metrics import categorical_accuracy -from tensorflow.python.keras._impl.keras.metrics import categorical_crossentropy -from tensorflow.python.keras._impl.keras.metrics import cosine_proximity -from tensorflow.python.keras._impl.keras.metrics import hinge -from tensorflow.python.keras._impl.keras.metrics import kullback_leibler_divergence -from tensorflow.python.keras._impl.keras.metrics import mean_absolute_error -from tensorflow.python.keras._impl.keras.metrics import mean_absolute_percentage_error -from tensorflow.python.keras._impl.keras.metrics import mean_squared_error -from tensorflow.python.keras._impl.keras.metrics import mean_squared_logarithmic_error -from tensorflow.python.keras._impl.keras.metrics import poisson -from tensorflow.python.keras._impl.keras.metrics import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.metrics import sparse_top_k_categorical_accuracy -from 
tensorflow.python.keras._impl.keras.metrics import squared_hinge -from tensorflow.python.keras._impl.keras.metrics import top_k_categorical_accuracy +from tensorflow.python.keras.metrics import binary_accuracy +from tensorflow.python.keras.metrics import binary_crossentropy +from tensorflow.python.keras.metrics import categorical_accuracy +from tensorflow.python.keras.metrics import categorical_crossentropy +from tensorflow.python.keras.metrics import cosine_proximity +from tensorflow.python.keras.metrics import hinge +from tensorflow.python.keras.metrics import kullback_leibler_divergence +from tensorflow.python.keras.metrics import mean_absolute_error +from tensorflow.python.keras.metrics import mean_absolute_percentage_error +from tensorflow.python.keras.metrics import mean_squared_error +from tensorflow.python.keras.metrics import mean_squared_logarithmic_error +from tensorflow.python.keras.metrics import poisson +from tensorflow.python.keras.metrics import sparse_categorical_crossentropy +from tensorflow.python.keras.metrics import sparse_top_k_categorical_accuracy +from tensorflow.python.keras.metrics import squared_hinge +from tensorflow.python.keras.metrics import top_k_categorical_accuracy # Auxiliary utils. 
# pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.metrics import deserialize -from tensorflow.python.keras._impl.keras.metrics import serialize -from tensorflow.python.keras._impl.keras.metrics import get +from tensorflow.python.keras.metrics import deserialize +from tensorflow.python.keras.metrics import serialize +from tensorflow.python.keras.metrics import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/models/__init__.py b/tensorflow/contrib/keras/api/keras/models/__init__.py index 2fb4ac0960d38f28a1c9c897a0f1aedf57e048ac..3a196984cd88cb60fbc2a9db306ce8fecf0febc0 100644 --- a/tensorflow/contrib/keras/api/keras/models/__init__.py +++ b/tensorflow/contrib/keras/api/keras/models/__init__.py @@ -18,13 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.models import load_model -from tensorflow.python.keras._impl.keras.models import Model -from tensorflow.python.keras._impl.keras.models import model_from_config -from tensorflow.python.keras._impl.keras.models import model_from_json -from tensorflow.python.keras._impl.keras.models import model_from_yaml -from tensorflow.python.keras._impl.keras.models import save_model -from tensorflow.python.keras._impl.keras.models import Sequential +from tensorflow.python.keras.models import load_model +from tensorflow.python.keras.models import Model +from tensorflow.python.keras.models import model_from_config +from tensorflow.python.keras.models import model_from_json +from tensorflow.python.keras.models import model_from_yaml +from tensorflow.python.keras.models import save_model +from tensorflow.python.keras.models import Sequential del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/optimizers/__init__.py b/tensorflow/contrib/keras/api/keras/optimizers/__init__.py index 
44f47bc47f4a0e31aaf2ac8f67cfdbef410d8c44..4849a06747958ab41b8b6309fa848aff3da3f633 100644 --- a/tensorflow/contrib/keras/api/keras/optimizers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/optimizers/__init__.py @@ -19,20 +19,20 @@ from __future__ import division from __future__ import print_function # Optimizer classes. -from tensorflow.python.keras._impl.keras.optimizers import Adadelta -from tensorflow.python.keras._impl.keras.optimizers import Adagrad -from tensorflow.python.keras._impl.keras.optimizers import Adam -from tensorflow.python.keras._impl.keras.optimizers import Adamax -from tensorflow.python.keras._impl.keras.optimizers import Nadam -from tensorflow.python.keras._impl.keras.optimizers import Optimizer -from tensorflow.python.keras._impl.keras.optimizers import RMSprop -from tensorflow.python.keras._impl.keras.optimizers import SGD +from tensorflow.python.keras.optimizers import Adadelta +from tensorflow.python.keras.optimizers import Adagrad +from tensorflow.python.keras.optimizers import Adam +from tensorflow.python.keras.optimizers import Adamax +from tensorflow.python.keras.optimizers import Nadam +from tensorflow.python.keras.optimizers import Optimizer +from tensorflow.python.keras.optimizers import RMSprop +from tensorflow.python.keras.optimizers import SGD # Auxiliary utils. 
# pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.optimizers import deserialize -from tensorflow.python.keras._impl.keras.optimizers import serialize -from tensorflow.python.keras._impl.keras.optimizers import get +from tensorflow.python.keras.optimizers import deserialize +from tensorflow.python.keras.optimizers import serialize +from tensorflow.python.keras.optimizers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py index b96e7675527041d3952b049f5f431d3df36eea4c..1f9e82b41bf09b235e93fa512a50ea4c3047c01b 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py @@ -18,20 +18,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.preprocessing.image import apply_transform -from tensorflow.python.keras._impl.keras.preprocessing.image import array_to_img -from tensorflow.python.keras._impl.keras.preprocessing.image import DirectoryIterator -from tensorflow.python.keras._impl.keras.preprocessing.image import flip_axis -from tensorflow.python.keras._impl.keras.preprocessing.image import ImageDataGenerator -from tensorflow.python.keras._impl.keras.preprocessing.image import img_to_array -from tensorflow.python.keras._impl.keras.preprocessing.image import Iterator -from tensorflow.python.keras._impl.keras.preprocessing.image import load_img -from tensorflow.python.keras._impl.keras.preprocessing.image import NumpyArrayIterator -from tensorflow.python.keras._impl.keras.preprocessing.image import random_channel_shift -from tensorflow.python.keras._impl.keras.preprocessing.image import random_rotation -from tensorflow.python.keras._impl.keras.preprocessing.image import random_shear -from 
tensorflow.python.keras._impl.keras.preprocessing.image import random_shift -from tensorflow.python.keras._impl.keras.preprocessing.image import random_zoom +from tensorflow.python.keras.preprocessing.image import apply_transform +from tensorflow.python.keras.preprocessing.image import array_to_img +from tensorflow.python.keras.preprocessing.image import DirectoryIterator +from tensorflow.python.keras.preprocessing.image import flip_axis +from tensorflow.python.keras.preprocessing.image import ImageDataGenerator +from tensorflow.python.keras.preprocessing.image import img_to_array +from tensorflow.python.keras.preprocessing.image import Iterator +from tensorflow.python.keras.preprocessing.image import load_img +from tensorflow.python.keras.preprocessing.image import NumpyArrayIterator +from tensorflow.python.keras.preprocessing.image import random_channel_shift +from tensorflow.python.keras.preprocessing.image import random_rotation +from tensorflow.python.keras.preprocessing.image import random_shear +from tensorflow.python.keras.preprocessing.image import random_shift +from tensorflow.python.keras.preprocessing.image import random_zoom del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py index 112f6af5e588bcb2e85fdbecea86f402742d44e7..9a93b6fb57ff5aaab25f2b606249a6022814b5e4 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.preprocessing.sequence import make_sampling_table -from tensorflow.python.keras._impl.keras.preprocessing.sequence import pad_sequences -from tensorflow.python.keras._impl.keras.preprocessing.sequence import skipgrams +from 
tensorflow.python.keras.preprocessing.sequence import make_sampling_table +from tensorflow.python.keras.preprocessing.sequence import pad_sequences +from tensorflow.python.keras.preprocessing.sequence import skipgrams del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py index 5bf1a2fb21dc27f7aa10cd08b1496e3991c61d2f..86386a9b6762d1c5cb3915ace64686cc25367e0f 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.preprocessing.text import one_hot -from tensorflow.python.keras._impl.keras.preprocessing.text import text_to_word_sequence -from tensorflow.python.keras._impl.keras.preprocessing.text import Tokenizer +from tensorflow.python.keras.preprocessing.text import one_hot +from tensorflow.python.keras.preprocessing.text import text_to_word_sequence +from tensorflow.python.keras.preprocessing.text import Tokenizer del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/regularizers/__init__.py b/tensorflow/contrib/keras/api/keras/regularizers/__init__.py index 3e707ccab577b5e28febd83d91f84d7b1c0d5d82..d668e39c09ca28239e56763f111fb01939bedc69 100644 --- a/tensorflow/contrib/keras/api/keras/regularizers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/regularizers/__init__.py @@ -19,19 +19,19 @@ from __future__ import division from __future__ import print_function # Regularizer functions / callable classes. 
-from tensorflow.python.keras._impl.keras.regularizers import L1L2 -from tensorflow.python.keras._impl.keras.regularizers import Regularizer +from tensorflow.python.keras.regularizers import L1L2 +from tensorflow.python.keras.regularizers import Regularizer # Functional interface. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.regularizers import l1 -from tensorflow.python.keras._impl.keras.regularizers import l2 -from tensorflow.python.keras._impl.keras.regularizers import l1_l2 +from tensorflow.python.keras.regularizers import l1 +from tensorflow.python.keras.regularizers import l2 +from tensorflow.python.keras.regularizers import l1_l2 # Auxiliary utils. -from tensorflow.python.keras._impl.keras.regularizers import deserialize -from tensorflow.python.keras._impl.keras.regularizers import serialize -from tensorflow.python.keras._impl.keras.regularizers import get +from tensorflow.python.keras.regularizers import deserialize +from tensorflow.python.keras.regularizers import serialize +from tensorflow.python.keras.regularizers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/utils/__init__.py b/tensorflow/contrib/keras/api/keras/utils/__init__.py index a7c2179fe7ad434356921a5fb8709aa5b1f33498..47cd01b924fb43e8a83836c58f8ced61e9e88268 100644 --- a/tensorflow/contrib/keras/api/keras/utils/__init__.py +++ b/tensorflow/contrib/keras/api/keras/utils/__init__.py @@ -18,21 +18,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file -from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence -from tensorflow.python.keras._impl.keras.utils.data_utils import SequenceEnqueuer -from tensorflow.python.keras._impl.keras.utils.generic_utils import custom_object_scope -from 
tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope -from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.python.keras._impl.keras.utils.generic_utils import get_custom_objects -from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar -from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object -from tensorflow.python.keras._impl.keras.utils.io_utils import HDF5Matrix -from tensorflow.python.keras._impl.keras.utils.layer_utils import convert_all_kernels_in_model -from tensorflow.python.keras._impl.keras.utils.np_utils import normalize -from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical -from tensorflow.python.keras._impl.keras.utils.vis_utils import plot_model +from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer +from tensorflow.python.keras.utils.data_utils import get_file +from tensorflow.python.keras.utils.data_utils import Sequence +from tensorflow.python.keras.utils.data_utils import SequenceEnqueuer +from tensorflow.python.keras.utils.generic_utils import custom_object_scope +from tensorflow.python.keras.utils.generic_utils import CustomObjectScope +from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras.utils.generic_utils import get_custom_objects +from tensorflow.python.keras.utils.generic_utils import Progbar +from tensorflow.python.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.keras.utils.io_utils import HDF5Matrix +from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model +from tensorflow.python.keras.utils.np_utils import normalize +from tensorflow.python.keras.utils.np_utils import to_categorical +from tensorflow.python.keras.utils.vis_utils import plot_model del absolute_import del division diff --git 
a/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py b/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py index a46f859273ea0117e29a403057f9f81bc758dd52..c4b7aa765c26bafbfcfe45df02e58d1cf1064b4b 100644 --- a/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py +++ b/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.wrappers.scikit_learn import KerasClassifier -from tensorflow.python.keras._impl.keras.wrappers.scikit_learn import KerasRegressor +from tensorflow.python.keras.wrappers.scikit_learn import KerasClassifier +from tensorflow.python.keras.wrappers.scikit_learn import KerasRegressor del absolute_import del division diff --git a/tensorflow/contrib/kfac/README.md b/tensorflow/contrib/kfac/README.md index 762a2f0b57e95e2fef3dd177070701afb410e93a..102626925db560e47cdc73eb1e25e08836cb4fba 100644 --- a/tensorflow/contrib/kfac/README.md +++ b/tensorflow/contrib/kfac/README.md @@ -1,5 +1,10 @@ # K-FAC: Kronecker-Factored Approximate Curvature +# WARNING: +# ==third_party/tensorflow/contrib/kfac is deprecated. This will be== +# ==removed on 15-07-2018. Please import third_party/tensorflow_kfac.== +# ==== + **K-FAC in TensorFlow** is an implementation of [K-FAC][kfac-paper], an approximate second-order optimization method, in TensorFlow. 
When applied to feedforward and convolutional neural networks, K-FAC can converge `>3.5x` diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index b7f63d8d94a7a427eb57afefeda3939f0c530f8e..03b9da793307b966632789fd11162306e6cd19f9 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import warnings + # pylint disable=long-line from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp from tensorflow.contrib.kfac.python.ops import estimator as est @@ -107,6 +109,10 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): ValueError: If momentum is non-zero and momentum_type is not 'regular' or 'adam'. """ + warnings.warn( + "third_party.tensorflow.contrib.kfac is deprecated. " + "This will be removed on 15-07-2018. Check README for further details.", + DeprecationWarning) # Parameters to be passed to the Fisher estimator: self._variables = var_list or tf_variables.trainable_variables self._cov_ema_decay = cov_ema_decay diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index 00f03a111ae8be7f49761ef5fb5a82810bcca182..bc3359693562deb1229a78a2db5c256c76f7fd8d 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -19,6 +19,8 @@ See the @{$python/contrib.layers} guide.
@@avg_pool2d @@avg_pool3d @@batch_norm +@@convolution +@@convolution1d @@convolution2d @@convolution3d @@conv2d_in_plane diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py index 49c3faf3b7f5eaa3b1542a1fdddcfaff99737a24..60e1d85ea9c08a51763fdaf08853f8d9b67347e5 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py @@ -458,7 +458,7 @@ def scattered_embedding_lookup_sparse(params, return embeddings -def embedding_lookup_unique(params, ids, name=None): +def embedding_lookup_unique(params, ids, partition_strategy="mod", name=None): """Version of embedding_lookup that avoids duplicate lookups. This can save communication in the case of repeated ids. @@ -470,6 +470,9 @@ def embedding_lookup_unique(params, ids, name=None): `PartitionedVariable`. Shape `[index, d1, d2, ...]`. ids: A one-dimensional `Tensor` with type `int32` or `int64` containing the ids to be looked up in `params`. Shape `[ids1, ids2, ...]`. + partition_strategy: A string specifying the partitioning strategy, relevant + if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default + is `"mod"`. name: A name for this operation (optional). 
Returns: @@ -485,7 +488,8 @@ def embedding_lookup_unique(params, ids, name=None): ids_flat = array_ops.reshape( ids, math_ops.reduce_prod(shape, keepdims=True)) unique_ids, idx = array_ops.unique(ids_flat) - unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids) + unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids, + partition_strategy) embeds_flat = array_ops.gather(unique_embeddings, idx) embed_shape = array_ops.concat( [shape, array_ops.shape(unique_embeddings)[1:]], 0) diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 2f3e57653c5d6d949c4dcc91635690322b7f90c4..b6d63c9640611abdda65f1205f544ee505dae1f0 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -57,10 +57,10 @@ from tensorflow.python.training import moving_averages __all__ = [ 'avg_pool2d', 'avg_pool3d', 'batch_norm', 'bias_add', 'conv2d', 'conv3d', 'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', 'convolution', - 'convolution2d', 'convolution2d_in_plane', 'convolution2d_transpose', - 'convolution3d', 'convolution3d_transpose', 'dense_to_sparse', - 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN', 'gdn', - 'images_to_sequence', 'layer_norm', 'linear', 'pool', 'max_pool2d', + 'convolution1d', 'convolution2d', 'convolution2d_in_plane', + 'convolution2d_transpose', 'convolution3d', 'convolution3d_transpose', + 'dense_to_sparse', 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN', + 'gdn', 'images_to_sequence', 'layer_norm', 'linear', 'pool', 'max_pool2d', 'max_pool3d', 'one_hot_encoding', 'relu', 'relu6', 'repeat', 'scale_gradient', 'separable_conv2d', 'separable_convolution2d', 'sequence_to_images', 'softmax', 'spatial_softmax', 'stack', 'unit_norm', @@ -2022,6 +2022,7 @@ class GDN(base.Layer): def beta_initializer(shape, dtype=None, partition_info=None): del partition_info # unused + pedestal = 
array_ops.constant(self._reparam_offset**2, dtype=self.dtype) return math_ops.sqrt(array_ops.ones(shape, dtype=dtype) + pedestal) def gamma_initializer(shape, dtype=None, partition_info=None): @@ -2029,6 +2030,7 @@ class GDN(base.Layer): assert len(shape) == 2 assert shape[0] == shape[1] eye = linalg_ops.eye(shape[0], dtype=dtype) + pedestal = array_ops.constant(self._reparam_offset**2, dtype=self.dtype) return math_ops.sqrt(self._gamma_init * eye + pedestal) beta = self.add_variable( diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index b01fd5d5c95ac15c76f9dbe7c77f7e76f12149a9..c5c7269b1f15849956e90654e3bcf8ab0eebc393 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1312,6 +1312,29 @@ class ConvolutionInPlaneTest(test.TestCase): self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5) + def testConv1dShape(self): + width = 7 + with self.test_session(): + images = random_ops.random_uniform((5, width, 3), seed=1) + output = layers_lib.convolution1d(images, 32, 3) + self.assertEqual(output.op.name, 'Conv/Relu') + self.assertListEqual(output.get_shape().as_list(), [5, width, 32]) + + def testConvInferSpatialDims(self): + depth, height, width = 7, 9, 11 + with self.test_session(): + images = np.random.uniform(size=(5, width, 4)).astype(np.float32) + output = layers_lib.convolution(images, 32, [3]) + self.assertListEqual(output.get_shape().as_list(), [5, width, 32]) + images = np.random.uniform(size=(5, height, width, 4)).astype(np.float32) + output = layers_lib.convolution(images, 32, [3, 3]) + self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32]) + images = np.random.uniform(size=(5, depth, height, width, + 4)).astype(np.float32) + output = layers_lib.convolution(images, 32, [3, 3, 3]) + self.assertListEqual(output.get_shape().as_list(), + [5, depth, height, width, 32]) + class 
DenseToSparseTest(test.TestCase): @@ -1333,7 +1356,7 @@ class DropoutTest(test.TestCase): with self.test_session(): images = np.random.uniform(size=(5, height, width, 3)) output = _layers.dropout(images) - self.assertEqual(output.op.name, 'Dropout/dropout/mul') + self.assertEqual(output.op.name, 'Dropout/dropout_1/mul') output.get_shape().assert_is_compatible_with( ops.convert_to_tensor(images).get_shape()) diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py index 8ed9f446bcd5f222f486e43125dafc595852e5ce..0e35b1aa8bf682c1b4f7e8d974d3e8fad69e33cb 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py @@ -46,6 +46,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +from tensorflow.python.util import tf_inspect __all__ = ["rev_block", "RevBlock", "recompute_grad"] @@ -449,6 +450,15 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False): `variable_scope(name, use_resource=True), which are the default in Eager mode and when running on TPU. + Warning: Because the function will be called again on the backwards pass, the + user should be careful to not use ops in their function that mutate state or + have randomness (for example, batch normalization or dropout). If the function + does have such operations, it is recommended that the function take the + `is_recomputing` keyword argument which will be `False` on the forward pass + and `True` on the backwards pass so that it can disable state changes when + `is_recomputing=True` (for example, not updating the moving averages in batch + normalization). + Args: fn: a function that takes Tensors (all as positional arguments) and returns a tuple of Tensors. 
@@ -482,6 +492,7 @@ def _is_on_tpu(): def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): """See recompute_grad.""" + has_is_recompute_kwarg = "is_recomputing" in tf_inspect.getargspec(fn).args for arg in args: if not isinstance(arg, framework_ops.Tensor): raise ValueError("All inputs to function must be Tensors") @@ -496,7 +507,10 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): vs = variable_scope.get_variable_scope() arg_scope = contrib_framework_ops.current_arg_scope() with backprop.GradientTape() as tape: - outputs = fn(*args) + fn_kwargs = {} + if has_is_recompute_kwarg: + fn_kwargs["is_recomputing"] = False + outputs = fn(*args, **fn_kwargs) original_vars = set(tape.watched_variables()) # Backward pass @@ -516,7 +530,10 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): with contrib_framework_ops.arg_scope(arg_scope): with variable_scope.variable_scope(vs, reuse=True): with backprop.GradientTape() as tape: - outputs = fn(*inputs) + fn_kwargs = {} + if has_is_recompute_kwarg: + fn_kwargs["is_recomputing"] = True + outputs = fn(*inputs, **fn_kwargs) recompute_vars = set(tape.watched_variables()) if original_vars != recompute_vars: raise ValueError(_WRONG_VARS_ERR) diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py index 997f53b9e1bbf9ac151cadd4a9f8e79c2e0ebca2..bc09ba8d439808c1582f207a99504012afcf33a6 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -21,9 +21,11 @@ from __future__ import print_function from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.layers.python.layers import rev_block_lib from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from 
tensorflow.python.layers import convolutional from tensorflow.python.layers import core as core_layers +from tensorflow.python.layers import normalization as normalization_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops @@ -342,6 +344,34 @@ class RecomputeTest(test.TestCase): for grad in grads: self.assertTrue(grad is not None) + def testWithIsRecomputeKwarg(self): + + kwarg_values = [] + + @rev_block_lib.recompute_grad + def layer_with_recompute(inputs, is_recomputing=False): + kwarg_values.append(is_recomputing) + out = core_layers.dense(inputs, 2) + out = normalization_layers.batch_normalization(out, training=True) + if is_recomputing: + # Ensure that the updates are not duplicated by popping off the latest + # 2 additions. + update_ops = ops.get_collection_ref(ops.GraphKeys.UPDATE_OPS) + update_ops.pop() + update_ops.pop() + return out + + x = array_ops.ones((2, 4), dtypes.float32) + with variable_scope.variable_scope("layer1", use_resource=True): + y = layer_with_recompute(x) + loss = math_ops.reduce_sum(y) + tvars = variables.trainable_variables() + gradients_impl.gradients(loss, [x] + tvars) + + update_ops = ops.get_collection(ops.GraphKeys.UPDATE_OPS) + self.assertEqual(2, len(update_ops)) + self.assertEqual([False, True], kwarg_values) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 4a360711f834354ce77b7a9579c05780a72c2661..b56a88659bbd4467600788fc8e3e9dbf38ce8244 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -284,6 +284,7 @@ py_test( tags = [ "manual", "noasan", # times out + "optonly", # test is flaky without optimization. 
], deps = [ ":learn", @@ -434,6 +435,7 @@ py_test( name = "kmeans_test", size = "medium", srcs = ["python/learn/estimators/kmeans_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = [ "noasan", # b/73741358 @@ -745,7 +747,7 @@ py_test( tf_py_test( name = "graph_io_test", - size = "small", + size = "medium", srcs = ["python/learn/learn_io/graph_io_test.py"], additional_deps = [ ":learn", diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index e28e6854a5097d66cb486be3e82f3726f5cc70fd..339c4e0e360ed9ef9906f0e51b64a0dc13826259 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -1862,12 +1862,12 @@ def _get_arguments(func): if hasattr(func, "__code__"): # Regular function. return tf_inspect.getargspec(func) - elif hasattr(func, "__call__"): - # Callable object. - return _get_arguments(func.__call__) elif hasattr(func, "func"): # Partial function. return _get_arguments(func.func) + elif hasattr(func, "__call__"): + # Callable object. 
+ return _get_arguments(func.__call__) def _verify_loss_fn_args(loss_fn): diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index 70b70af98c51dcb991c19152607272673953ee2a..e100bc7a1e7be4896e9ab1c965775b5185b38897 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -31,7 +31,6 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values -from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib @@ -51,6 +50,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import session_run_hook from tensorflow.python.training import training as train +from tensorflow.python.training import training_util # The default learning rate of 0.2 is a historical artifact of the initial @@ -244,7 +244,9 @@ def sdca_model_fn(features, labels, mode, params): parent_scope = "linear" with variable_scope.variable_scope( - values=features.values(), name_or_scope=parent_scope) as scope: + values=features.values(), + name_or_scope=parent_scope, + partitioner=optimizer.partitioner) as scope: features = features.copy() features.update(layers.transform_features(features, feature_columns)) logits, columns_to_variables, bias = ( diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py index 0a863f0e20c05d3372ffd8f7677cd518390ecc9d..597ca4e86dbf66c86182f14a2a364b662d52fb0a 100644 --- 
a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py @@ -43,6 +43,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import partitioned_variables from tensorflow.python.platform import test from tensorflow.python.training import ftrl from tensorflow.python.training import input as input_lib @@ -966,6 +967,63 @@ class LinearClassifierTest(test.TestCase): scores = classifier.evaluate(input_fn=input_fn, steps=1) self.assertGreater(scores['accuracy'], 0.9) + def testSdcaOptimizerPartitionedVariables(self): + """Tests LinearClassifier with SDCAOptimizer with partitioned variables.""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + constant_op.constant([[0.6], [0.8], [0.3]]), + 'sq_footage': + constant_op.constant([[900.0], [700.0], [600.0]]), + 'country': + sparse_tensor.SparseTensor( + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 3], [2, 1]], + dense_shape=[3, 5]), + 'weights': + constant_op.constant([[3.0], [1.0], [1.0]]) + }, constant_op.constant([[1], [0], [1]]) + + price = feature_column_lib.real_valued_column('price') + sq_footage_bucket = feature_column_lib.bucketized_column( + feature_column_lib.real_valued_column('sq_footage'), + boundaries=[650.0, 800.0]) + country = feature_column_lib.sparse_column_with_hash_bucket( + 'country', hash_bucket_size=5) + sq_footage_country = feature_column_lib.crossed_column( + [sq_footage_bucket, country], hash_bucket_size=10) + + sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer( + example_id_column='example_id', + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)) + + tf_config = { + 'cluster': { + run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1'] + } + } + with 
test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + config = run_config.RunConfig() + # Because we did not start a distributed cluster, we need to pass an + # empty ClusterSpec, otherwise the device_setter will look for + # distributed jobs, such as "/job:ps" which are not present. + config._cluster_spec = server_lib.ClusterSpec({}) + + classifier = linear.LinearClassifier( + feature_columns=[price, sq_footage_bucket, country, sq_footage_country], + weight_column_name='weights', + optimizer=sdca_optimizer, + config=config) + classifier.fit(input_fn=input_fn, steps=50) + scores = classifier.evaluate(input_fn=input_fn, steps=1) + print('all scores = {}'.format(scores)) + self.assertGreater(scores['accuracy'], 0.9) + def testEval(self): """Tests that eval produces correct metrics. """ @@ -1540,6 +1598,60 @@ class LinearRegressorTest(test.TestCase): loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss'] self.assertLess(loss, 0.05) + def testSdcaOptimizerPartitionedVariables(self): + """Tests LinearRegressor with SDCAOptimizer with partitioned variables.""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + constant_op.constant([0.6, 0.8, 0.3]), + 'sq_footage': + constant_op.constant([[900.0], [700.0], [600.0]]), + 'country': + sparse_tensor.SparseTensor( + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 3], [2, 1]], + dense_shape=[3, 5]), + 'weights': + constant_op.constant([[3.0], [5.0], [7.0]]) + }, constant_op.constant([[1.55], [-1.25], [-3.0]]) + + price = feature_column_lib.real_valued_column('price') + sq_footage_bucket = feature_column_lib.bucketized_column( + feature_column_lib.real_valued_column('sq_footage'), + boundaries=[650.0, 800.0]) + country = feature_column_lib.sparse_column_with_hash_bucket( + 'country', hash_bucket_size=5) + sq_footage_country = feature_column_lib.crossed_column( + [sq_footage_bucket, country], hash_bucket_size=10) + sdca_optimizer = 
sdca_optimizer_lib.SDCAOptimizer( + example_id_column='example_id', symmetric_l2_regularization=1.0, + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)) + tf_config = { + 'cluster': { + run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1'] + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + config = run_config.RunConfig() + # Because we did not start a distributed cluster, we need to pass an + # empty ClusterSpec, otherwise the device_setter will look for + # distributed jobs, such as "/job:ps" which are not present. + config._cluster_spec = server_lib.ClusterSpec({}) + + regressor = linear.LinearRegressor( + feature_columns=[price, sq_footage_bucket, country, sq_footage_country], + weight_column_name='weights', + optimizer=sdca_optimizer, + config=config) + regressor.fit(input_fn=input_fn, steps=20) + loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss'] + self.assertLess(loss, 0.05) + def testSdcaOptimizerSparseFeaturesWithL1Reg(self): """Tests LinearClassifier with SDCAOptimizer and sparse features.""" diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index dfc6a393d069fccb0fa93dc265f744e199db0dcf..f8a3709ee57a32734afa7ac8133271c75d152b2c 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -38,19 +38,19 @@ from tensorflow.contrib.learn.python.learn import trainable from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.contrib.tpu.python.tpu import tpu_estimator from tensorflow.python.estimator import estimator as core_estimator -from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import saver from 
tensorflow.python.training import server_lib from tensorflow.python.util import compat +from tensorflow.python.util import function_utils __all__ = ["Experiment"] def _get_standardized_predicate_fn(predicate_fn): - pred_fn_args = estimator_util.fn_args(predicate_fn) + pred_fn_args = function_utils.fn_args(predicate_fn) if "checkpoint_path" not in pred_fn_args: # pylint: disable=unused-argument def _pred_fn_wrapper(eval_results, checkpoint_path): @@ -505,7 +505,7 @@ class Experiment(object): eval_result = None last_warning_time = 0 while (not predicate_fn or predicate_fn( - eval_result, checkpoint_path=previous_path if eval_result else None)): + eval_result, checkpoint_path=previous_path)): # Exit if we have already reached number of steps to train. if self._has_training_stopped(eval_result): logging.info("Exiting continuous eval, global_step=%s >= " diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index d10927a0cdd5c67c8d2a8e569153235ee175ec4d..fb16c94c29660e2777942ea9cf30da51dbf90571 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -500,7 +500,7 @@ class ExperimentTest(test.TestCase): noop_hook = _NoopHook() def _predicate_fn(eval_result, checkpoint_path): - self.assertEqual(not eval_result, + self.assertEqual(eval_result is None, checkpoint_path is None) return est.eval_count < 3 # pylint: disable=cell-var-from-loop diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index b5741967ab52568725d7c9f03a0cc0b0f63f7459..ef0e08a777779e04f70d11fe83280ccaf1c178fd 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -35,6 +35,8 @@ from tensorflow.python.ops import control_flow_ops from 
tensorflow.python.ops import gen_sdca_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import googletest @@ -132,15 +134,22 @@ def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero): return examples_dict, variables_dict -def make_variable_dict(max_age, max_gender): +def make_variable_dict(max_age, max_gender, partitioned=False): # TODO(sibyl-toe9oF2e): Figure out how to derive max_age & max_gender from # examples_dict. - age_weights = variables_lib.Variable( - array_ops.zeros( - [max_age + 1], dtype=dtypes.float32)) - gender_weights = variables_lib.Variable( - array_ops.zeros( - [max_gender + 1], dtype=dtypes.float32)) + partitioner = None + if partitioned: + partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2, + axis=0) + with variable_scope.variable_scope( + name_or_scope='variables', + partitioner=partitioner): + age_weights = variables_lib.Variable( + array_ops.zeros( + [max_age + 1], dtype=dtypes.float32)) + gender_weights = variables_lib.Variable( + array_ops.zeros( + [max_gender + 1], dtype=dtypes.float32)) return dict( sparse_features_weights=[age_weights, gender_weights], dense_features_weights=[]) @@ -265,6 +274,54 @@ class SdcaWithLogisticLossTest(SdcaModelTest): self.assertAllClose( 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2) + def testPartitionedPrimals(self): + # Setup test data + example_protos = [ + make_example_proto({ + 'age': [0], + 'gender': [0] + }, 0), + make_example_proto({ + 'age': [1], + 'gender': [1] + }, 1), + ] + example_weights = [1.0, 1.0] + for num_shards in _SHARD_NUMBERS: + with self._single_threaded_test_session(): + examples = make_example_dict(example_protos, example_weights) + variables = make_variable_dict(1, 1, 
partitioned=True) + options = dict( + symmetric_l2_regularization=1, + symmetric_l1_regularization=0, + num_table_shards=num_shards, + loss_type='logistic_loss') + + lr = SdcaModel(examples, variables, options) + variables_lib.global_variables_initializer().run() + unregularized_loss = lr.unregularized_loss(examples) + loss = lr.regularized_loss(examples) + predictions = lr.predictions(examples) + self.assertAllClose(0.693147, unregularized_loss.eval()) + self.assertAllClose(0.693147, loss.eval()) + train_op = lr.minimize() + for _ in range(_MAX_ITERATIONS): + train_op.run() + lr.update_weights(train_op).run() + # The high tolerance in unregularized_loss comparisons is due to the + # fact that it's possible to trade off unregularized_loss vs. + # regularization and still have a sum that is quite close to the + # optimal regularized_loss value. SDCA's duality gap only ensures that + # the regularized_loss is within 0.01 of optimal. + # 0.525457 is the optimal regularized_loss. + # 0.411608 is the unregularized_loss at that optimum. + self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.05) + self.assertAllClose(0.525457, loss.eval(), atol=0.01) + predicted_labels = get_binary_predictions_for_logistic(predictions) + self.assertAllEqual([0, 1], predicted_labels.eval()) + self.assertAllClose( + 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2) + def testSparseRandom(self): dim = 20 num_examples = 1000 @@ -320,7 +377,10 @@ class SdcaWithLogisticLossTest(SdcaModelTest): train_op.run() def testDistributedSimple(self): - # Setup test data + # Distributed SDCA may not converge if the workers concurrently update the + # same example. In this test the examples are partitioned across workers. + # The examples are the same for all workers; only the example_ids are + # different.
example_protos = [ make_example_proto({ 'age': [0], @@ -332,13 +392,19 @@ class SdcaWithLogisticLossTest(SdcaModelTest): }, 1), ] example_weights = [1.0, 1.0] + examples = make_example_dict(example_protos, example_weights) + example_ids = array_ops.placeholder( + dtypes.string, shape=(len(example_weights),)) + examples['example_ids'] = example_ids + variables = make_variable_dict(1, 1) for num_shards in _SHARD_NUMBERS: for num_loss_partitions in _NUM_LOSS_PARTITIONS: with self._single_threaded_test_session(): - examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(1, 1) options = dict( - symmetric_l2_regularization=1, + # Keep the same solution as for TestSimple: since the number of + # examples is multiplied by num_loss_partitions, also multiply + # L2 by the same value. + symmetric_l2_regularization=num_loss_partitions, symmetric_l1_regularization=0, loss_type='logistic_loss', num_table_shards=num_shards, @@ -354,32 +420,30 @@ class SdcaWithLogisticLossTest(SdcaModelTest): train_op = lr.minimize() - def minimize(): + def minimize(worker_id): with self._single_threaded_test_session(): + feed_dict = {example_ids: [ + str(i + worker_id*len(example_weights)) for i in range( + len(example_weights))]} for _ in range(_MAX_ITERATIONS): - train_op.run() # pylint: disable=cell-var-from-loop + train_op.run(feed_dict=feed_dict) # pylint: disable=cell-var-from-loop threads = [] - for _ in range(num_loss_partitions): - threads.append(threading.Thread(target=minimize)) + for worker_id in range(num_loss_partitions): + threads.append(threading.Thread(target=minimize, args=(worker_id,))) threads[-1].start() for t in threads: t.join() - lr.update_weights(train_op).run() - - # The high tolerance in unregularized_loss comparisons is due to the - # fact that it's possible to trade off unregularized_loss vs. - # regularization and still have a sum that is quite close to the - # optimal regularized_loss value.
SDCA's duality gap only ensures - # that the regularized_loss is within 0.01 of optimal. - # 0.525457 is the optimal regularized_loss. - # 0.411608 is the unregularized_loss at that optimum. - self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.05) - self.assertAllClose(0.525457, loss.eval(), atol=0.01) + lr.update_weights(train_op).run(feed_dict={ + example_ids: [str(i) for i in range(len(example_weights))]}) + + # Test only the unregularized loss because the optimal value of the + # regularized loss depends on num_loss_partitions. + self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.02) predicted_labels = get_binary_predictions_for_logistic(predictions) self.assertAllEqual([0, 1], predicted_labels.eval()) - self.assertTrue(lr.approximate_duality_gap().eval() < 0.02) + self.assertNear(0.0, lr.approximate_duality_gap().eval(), 0.02) def testSimpleNoL2(self): # Same as test above (so comments from above apply) but without an L2. diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index f980746a19fb8e0a02b9d023c127da7ab33e457f..0047d5753a773ce814d685f89da9ae6b04d21cb6 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -22,12 +22,14 @@ import collections from six.moves import range from tensorflow.contrib.linear_optimizer.python.ops.sharded_mutable_dense_hashtable import ShardedMutableDenseHashTable +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework.ops import internal_convert_to_tensor from tensorflow.python.framework.ops import name_scope from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gen_sdca_ops from 
tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -43,9 +45,6 @@ __all__ = ['SdcaModel'] class SdcaModel(object): """Stochastic dual coordinate ascent solver for linear models. - This class currently only supports a single machine (multi-threaded) - implementation. We expect the weights and duals to fit in a single machine. - Loss functions supported: * Binary logistic loss @@ -182,18 +181,41 @@ class SdcaModel(object): # TODO(sibyl-Aix6ihai): Use optimizer interface to make use of slot creation logic. def _create_slots(self): - # Make internal variables which have the updates before applying L1 - # regularization. + """Make unshrinked internal variables (slots).""" + # Unshrinked variables have the updates before applying L1 regularization. + # Each unshrinked slot variable is either a `Variable` or list of + # `Variable`, depending on the value of its corresponding primary variable. + # We avoid using `PartitionedVariable` for the unshrinked slots since we do + # not need any of the extra information. self._slots = collections.defaultdict(list) for name in ['sparse_features_weights', 'dense_features_weights']: for var in self._variables[name]: - with ops.device(var.device): - # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 is - # fixed - self._slots['unshrinked_' + name].append( - var_ops.Variable( - array_ops.zeros_like(var.initialized_value(), dtypes.float32), - name=var.op.name + '_unshrinked/SDCAOptimizer')) + # Our primary variable may be either a PartitionedVariable, or a list + # of Variables (each representing a partition). + if (isinstance(var, var_ops.PartitionedVariable) or + isinstance(var, list)): + var_list = [] + # pylint: disable=protected-access + for v in var: + with ops.colocate_with(v): + # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 + # is fixed. 
+ slot_var = var_ops.Variable( + initial_value=array_ops.zeros_like(v.initialized_value(), + dtypes.float32), + name=v.op.name + '_unshrinked/SDCAOptimizer') + var_list.append(slot_var) + self._slots['unshrinked_' + name].append(var_list) + # pylint: enable=protected-access + else: + with ops.device(var.device): + # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 is + # fixed. + self._slots['unshrinked_' + name].append( + var_ops.Variable( + array_ops.zeros_like(var.initialized_value(), + dtypes.float32), + name=var.op.name + '_unshrinked/SDCAOptimizer')) def _assertSpecified(self, items, check_in): for x in items: @@ -205,16 +227,25 @@ class SdcaModel(object): if not isinstance(check_in[x], list): raise ValueError(x + ' must be a list.') + def _var_to_list(self, var): + """Wraps var in a list if it is not a list or PartitionedVariable.""" + if not (isinstance(var, list) or + isinstance(var, var_ops.PartitionedVariable)): + var = [var] + return var + def _l1_loss(self): """Computes the (un-normalized) l1 loss of the model.""" with name_scope('sdca/l1_loss'): sums = [] for name in ['sparse_features_weights', 'dense_features_weights']: - for weights in self._convert_n_to_tensor(self._variables[name]): - with ops.device(weights.device): - sums.append( - math_ops.reduce_sum( - math_ops.abs(math_ops.cast(weights, dtypes.float64)))) + for var in self._variables[name]: + for v in self._var_to_list(var): + weights = internal_convert_to_tensor(v) + with ops.device(weights.device): + sums.append( + math_ops.reduce_sum( + math_ops.abs(math_ops.cast(weights, dtypes.float64)))) # SDCA L1 regularization cost is: l1 * sum(|weights|) return self._options['symmetric_l1_regularization'] * math_ops.add_n(sums) @@ -223,17 +254,37 @@ class SdcaModel(object): with name_scope('sdca/l2_loss'): sums = [] for name in ['sparse_features_weights', 'dense_features_weights']: - for weights in self._convert_n_to_tensor(self._variables[name]): - with ops.device(weights.device): - 
sums.append( - math_ops.reduce_sum( - math_ops.square(math_ops.cast(weights, dtypes.float64)))) + for var in self._variables[name]: + for v in self._var_to_list(var): + weights = internal_convert_to_tensor(v) + with ops.device(weights.device): + sums.append(math_ops.reduce_sum(math_ops.square(math_ops.cast( + weights, dtypes.float64)))) # SDCA L2 regularization cost is: l2 * sum(weights^2) / 2 return l2 * math_ops.add_n(sums) / 2.0 def _convert_n_to_tensor(self, input_list, as_ref=False): """Converts input list to a set of tensors.""" - return [internal_convert_to_tensor(x, as_ref=as_ref) for x in input_list] + # input_list can be a list of Variables (that are implicitly partitioned), + # in which case the underlying logic in internal_convert_to_tensor will not + # concatenate the partitions together. This method takes care of the + # concatenating (we only allow partitioning on the first axis). + output_list = [] + for x in input_list: + tensor_to_convert = x + if isinstance(x, list) or isinstance(x, var_ops.PartitionedVariable): + # We only allow for partitioning on the first axis. 
+ tensor_to_convert = array_ops.concat(x, axis=0) + output_list.append(internal_convert_to_tensor( + tensor_to_convert, as_ref=as_ref)) + return output_list + + def _get_first_dimension_size_statically(self, w, num_partitions): + """Compute the static size of the first dimension for a sharded variable.""" + dim_0_size = w[0].get_shape()[0] + for p in range(1, num_partitions): + dim_0_size += w[p].get_shape()[0] + return dim_0_size def _linear_predictions(self, examples): """Returns predictions of the form w*x.""" @@ -286,6 +337,28 @@ class SdcaModel(object): result = math_ops.sigmoid(result) return result + def _get_partitioned_update_ops(self, + v_num, + num_partitions_by_var, + p_assignments_by_var, + gather_ids_by_var, + weights, + full_update, + p_assignments, + num_partitions): + """Get updates for partitioned variables.""" + num_partitions = num_partitions_by_var[v_num] + p_assignments = p_assignments_by_var[v_num] + gather_ids = gather_ids_by_var[v_num] + updates = data_flow_ops.dynamic_partition( + full_update, p_assignments, num_partitions) + update_ops = [] + for p in range(num_partitions): + with ops.colocate_with(weights[p]): + result = state_ops.scatter_add(weights[p], gather_ids[p], updates[p]) + update_ops.append(result) + return update_ops + def minimize(self, global_step=None, name=None): """Add operations to train a linear model by minimizing the loss function. @@ -318,18 +391,89 @@ class SdcaModel(object): # Solver returns example_state_update, new delta sparse_feature_weights # and delta dense_feature_weights. - weights_tensor = self._convert_n_to_tensor(self._slots[ - 'unshrinked_sparse_features_weights']) sparse_weights = [] sparse_indices = [] - for w, i in zip(weights_tensor, sparse_feature_indices): - # Find the feature ids to lookup in the variables. 
- with ops.device(w.device): - sparse_indices.append( - math_ops.cast( - array_ops.unique(math_ops.cast(i, dtypes.int32))[0], - dtypes.int64)) - sparse_weights.append(array_ops.gather(w, sparse_indices[-1])) + # If we have partitioned variables, keep a few lists of Tensors around + # that we need for the assign_add after the op call to + # gen_sdca_ops.sdca_optimizer(). + num_partitions_by_var = [] + p_assignments_by_var = [] + gather_ids_by_var = [] + for w, i in zip(self._slots['unshrinked_sparse_features_weights'], + sparse_feature_indices): + # Append the sparse_indices (in full-variable space). + sparse_idx = math_ops.cast( + array_ops.unique(math_ops.cast(i, dtypes.int32))[0], + dtypes.int64) + sparse_indices.append(sparse_idx) + if isinstance(w, list) or isinstance(w, var_ops.PartitionedVariable): + num_partitions = len(w) + flat_ids = array_ops.reshape(sparse_idx, [-1]) + # We use div partitioning, which is easiest to support downstream. + # Compute num_total_ids as the sum of dim-0 of w, then assign + # to partitions based on a constant number of ids per partition. + # Optimize if we already know the full shape statically. 
+ dim_0_size = self._get_first_dimension_size_statically( + w, num_partitions) + + if dim_0_size.value: + num_total_ids = constant_op.constant(dim_0_size.value, + flat_ids.dtype) + else: + dim_0_sizes = [] + for p in range(num_partitions): + if w[p].get_shape()[0].value is not None: + dim_0_sizes.append(w[p].get_shape()[0].value) + else: + with ops.colocate_with(w[p]): + dim_0_sizes.append(array_ops.shape(w[p])[0]) + num_total_ids = math_ops.reduce_sum( + math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype)) + ids_per_partition = num_total_ids // num_partitions + extras = num_total_ids % num_partitions + + p_assignments = math_ops.maximum( + flat_ids // (ids_per_partition + 1), + (flat_ids - extras) // ids_per_partition) + + # Emulate a conditional using a boolean indicator tensor + new_ids = array_ops.where(p_assignments < extras, + flat_ids % (ids_per_partition + 1), + (flat_ids - extras) % ids_per_partition) + + # Cast partition assignments to int32 for use in dynamic_partition. + # There really should not be more than 2^32 partitions. + p_assignments = math_ops.cast(p_assignments, dtypes.int32) + # Partition list of ids based on assignments into num_partitions + # separate lists. + gather_ids = data_flow_ops.dynamic_partition(new_ids, + p_assignments, + num_partitions) + # Append these to the lists for use in the later update. + num_partitions_by_var.append(num_partitions) + p_assignments_by_var.append(p_assignments) + gather_ids_by_var.append(gather_ids) + + # Gather the weights from each partition. + partition_gathered_weights = [] + for p in range(num_partitions): + with ops.colocate_with(w[p]): + partition_gathered_weights.append( + array_ops.gather(w[p], gather_ids[p])) + + # Stitch the weights back together in the same order they were before + # we dynamic_partitioned them. 
+ condition_indices = data_flow_ops.dynamic_partition( + math_ops.range(array_ops.shape(new_ids)[0]), + p_assignments, num_partitions) + batch_gathered_weights = data_flow_ops.dynamic_stitch( + condition_indices, partition_gathered_weights) + else: + w_as_tensor = internal_convert_to_tensor(w) + with ops.device(w_as_tensor.device): + batch_gathered_weights = array_ops.gather( + w_as_tensor, sparse_idx) + sparse_weights.append(batch_gathered_weights) # pylint: disable=protected-access esu, sfw, dfw = gen_sdca_ops.sdca_optimizer( @@ -355,12 +499,25 @@ class SdcaModel(object): with ops.control_dependencies([esu]): update_ops = [self._hashtable.insert(example_ids_hashed, esu)] # Update the weights before the proximal step. - for w, i, u in zip(self._slots['unshrinked_sparse_features_weights'], - sparse_indices, sfw): - update_ops.append(state_ops.scatter_add(w, i, u)) + for v_num, (w, i, u) in enumerate( + zip(self._slots['unshrinked_sparse_features_weights'], + sparse_indices, sfw)): + if (isinstance(w, var_ops.PartitionedVariable) or + isinstance(w, list)): + update_ops += self._get_partitioned_update_ops( + v_num, num_partitions_by_var, p_assignments_by_var, + gather_ids_by_var, w, u, p_assignments, num_partitions) + else: + update_ops.append(state_ops.scatter_add(w, i, u)) for w, u in zip(self._slots['unshrinked_dense_features_weights'], dfw): - update_ops.append(w.assign_add(u)) - + if (isinstance(w, var_ops.PartitionedVariable) or + isinstance(w, list)): + split_updates = array_ops.split( + u, num_or_size_splits=[v.shape.as_list()[0] for v in w]) + for v, split_update in zip(w, split_updates): + update_ops.append(state_ops.assign_add(v, split_update)) + else: + update_ops.append(state_ops.assign_add(w, u)) if not global_step: return control_flow_ops.group(*update_ops) with ops.control_dependencies(update_ops): @@ -385,21 +542,22 @@ class SdcaModel(object): for name in ['sparse_features_weights', 'dense_features_weights']: for var, slot_var in 
zip(self._variables[name], self._slots['unshrinked_' + name]): - update_ops.append(var.assign(slot_var)) + for v, sv in zip(self._var_to_list(var), self._var_to_list(slot_var)): + update_ops.append(v.assign(sv)) # Apply proximal step. with ops.control_dependencies(update_ops): update_ops = [] for name in ['sparse_features_weights', 'dense_features_weights']: for var in self._variables[name]: - with ops.device(var.device): - # pylint: disable=protected-access - update_ops.append( - gen_sdca_ops.sdca_shrink_l1( - self._convert_n_to_tensor( - [var], as_ref=True), - l1=self._symmetric_l1_regularization(), - l2=self._symmetric_l2_regularization())) + for v in self._var_to_list(var): + with ops.device(v.device): + # pylint: disable=protected-access + update_ops.append( + gen_sdca_ops.sdca_shrink_l1( + self._convert_n_to_tensor([v], as_ref=True), + l1=self._symmetric_l1_regularization(), + l2=self._symmetric_l2_regularization())) return control_flow_ops.group(*update_ops) def approximate_duality_gap(self): diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py index d4e54c82f988e0adcd16aad29702ee9f8b16aea3..200e7de6b95f17672c6ef51f887b15f9d185f775 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py @@ -116,6 +116,7 @@ def sdca_model_fn(features, labels, mode, params, config=None): num_loss_partitions = params["num_loss_partitions"] weight_column_name = params["weight_column_name"] update_weights_hook = params.get("update_weights_hook", None) + partitioner = params["partitioner"] loss_type = None if isinstance(head, head_lib._BinarySvmHead): # pylint: disable=protected-access @@ -136,12 +137,14 @@ def sdca_model_fn(features, labels, mode, params, config=None): example_id_column=example_id_column, num_loss_partitions=n_loss_partitions, symmetric_l1_regularization=l1_regularization, - 
symmetric_l2_regularization=l2_regularization) + symmetric_l2_regularization=l2_regularization, + partitioner=partitioner) parent_scope = "linear" with variable_scope.variable_scope( - values=features.values(), name_or_scope=parent_scope) as scope: + values=features.values(), name_or_scope=parent_scope, + partitioner=partitioner) as scope: features = features.copy() features.update(layers.transform_features(features, feature_columns)) logits, columns_to_variables, bias = ( @@ -213,7 +216,8 @@ class _SDCAEstimator(estimator.Estimator): l2_regularization=1.0, num_loss_partitions=None, config=None, - feature_engineering_fn=None): + feature_engineering_fn=None, + partitioner=None): """Construct a `_SDCAEstimator` estimator object. Args: @@ -241,6 +245,8 @@ class _SDCAEstimator(estimator.Estimator): feature_engineering_fn: Feature engineering function. Takes features and labels which are the output of `input_fn` and returns features and labels which will be fed into the model. + partitioner: Variable partitioner for the primal weights (`div` + partitioning strategy will be used). Returns: A `_SDCAEstimator` estimator. @@ -267,6 +273,7 @@ class _SDCAEstimator(estimator.Estimator): "l2_regularization": l2_regularization, "weight_column_name": weight_column_name, "update_weights_hook": _SdcaUpdateWeightsHook(), + "partitioner": partitioner, } super(_SDCAEstimator, self).__init__( @@ -336,7 +343,8 @@ class SDCALogisticClassifier(_SDCAEstimator): l2_regularization=1.0, num_loss_partitions=None, config=None, - feature_engineering_fn=None): + feature_engineering_fn=None, + partitioner=None): """Construct a `SDCALogisticClassifier` object. Args: @@ -361,6 +369,8 @@ class SDCALogisticClassifier(_SDCAEstimator): feature_engineering_fn: Feature engineering function. Takes features and labels which are the output of `input_fn` and returns features and labels which will be fed into the model. 
+ partitioner: Variable partitioner for the primal weights (`div` + partitioning strategy will be used). Returns: A `SDCALogisiticClassifier` estimator. @@ -376,7 +386,8 @@ class SDCALogisticClassifier(_SDCAEstimator): l2_regularization=l2_regularization, num_loss_partitions=num_loss_partitions, config=config, - feature_engineering_fn=None) + feature_engineering_fn=None, + partitioner=partitioner) def predict_classes(self, input_fn=None): """Runs inference to determine the predicted class. @@ -463,7 +474,8 @@ class SDCALinearRegressor(_SDCAEstimator): l2_regularization=1.0, num_loss_partitions=None, config=None, - feature_engineering_fn=None): + feature_engineering_fn=None, + partitioner=None): """Construct a `SDCALinearRegressor` estimator object. @@ -489,6 +501,8 @@ class SDCALinearRegressor(_SDCAEstimator): feature_engineering_fn: Feature engineering function. Takes features and labels which are the output of `input_fn` and returns features and labels which will be fed into the model. + partitioner: Variable partitioner for the primal weights (`div` + partitioning strategy will be used). Returns: A `SDCALinearRegressor` estimator. @@ -503,7 +517,8 @@ class SDCALinearRegressor(_SDCAEstimator): l2_regularization=l2_regularization, num_loss_partitions=num_loss_partitions, config=config, - feature_engineering_fn=None) + feature_engineering_fn=None, + partitioner=partitioner) def predict_scores(self, input_fn): """Returns predicted scores for given features. 
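Note on the `partitioner` docstrings above: they state that the `div` partitioning strategy is used for the primal weights. As a rough illustration of what that implies for id placement, here is a minimal plain-Python sketch. It assumes `div` means contiguous id ranges, with the first `vocab_size % num_partitions` partitions holding one extra id (the convention documented for `tf.nn.embedding_lookup`). The helper name `div_partition` and its arguments are hypothetical and do not appear in this change.

```python
def div_partition(ids, vocab_size, num_partitions):
    """Map global ids to (partition, local id) under a `div`-style split.

    Hypothetical helper for illustration only; not part of the SDCA code.
    """
    base, extras = divmod(vocab_size, num_partitions)
    # The first `extras` partitions hold `base + 1` ids each; the rest hold `base`.
    boundary = extras * (base + 1)
    assignments, local_ids = [], []
    for i in ids:
        if i < boundary:
            p, local = divmod(i, base + 1)
        else:
            p, local = divmod(i - boundary, base)
            p += extras
        assignments.append(p)
        local_ids.append(local)
    return assignments, local_ids


# 10 ids over 3 partitions: partition 0 owns ids 0..3, 1 owns 4..6, 2 owns 7..9.
parts, local = div_partition([0, 3, 4, 6, 7, 9], vocab_size=10, num_partitions=3)
```

Each sparse weight update would then be routed to partition `parts[k]` at row `local[k]`, which is the role `p_assignments` and the gathered ids appear to play in the partitioned update path.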
diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py index bed3d5139fcbf9d9e8b85605c752736f26af6793..647667188238dc18b137eaad98356a79b3a549b4 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py @@ -25,6 +25,7 @@ from tensorflow.contrib.linear_optimizer.python import sdca_estimator from tensorflow.core.protobuf import config_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import partitioned_variables from tensorflow.python.platform import test @@ -273,6 +274,47 @@ class SDCALogisticClassifierTest(test.TestCase): metrics = classifier.evaluate(input_fn=input_fn, steps=1) self.assertGreater(metrics['accuracy'], 0.9) + def testPartitionedMixedFeatures(self): + """Tests SDCALogisticClassifier with a mix of features (partitioned).""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + constant_op.constant([[0.6], [0.8], [0.3]]), + 'sq_footage': + constant_op.constant([900.0, 700.0, 600.0]), + 'country': + sparse_tensor.SparseTensor( + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 3], [2, 1]], + dense_shape=[3, 5]), + 'weights': + constant_op.constant([[3.0], [1.0], [1.0]]) + }, constant_op.constant([[1], [0], [1]]) + + with self._single_threaded_test_session(): + price = feature_column_lib.real_valued_column('price') + sq_footage_bucket = feature_column_lib.bucketized_column( + feature_column_lib.real_valued_column('sq_footage'), + boundaries=[650.0, 800.0]) + country = feature_column_lib.sparse_column_with_hash_bucket( + 'country', hash_bucket_size=5) + sq_footage_country = feature_column_lib.crossed_column( + [sq_footage_bucket, country], hash_bucket_size=10) + classifier = sdca_estimator.SDCALogisticClassifier( + 
example_id_column='example_id', + feature_columns=[ + price, sq_footage_bucket, country, sq_footage_country + ], + weight_column_name='weights', + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)) + classifier.fit(input_fn=input_fn, steps=50) + metrics = classifier.evaluate(input_fn=input_fn, steps=1) + self.assertGreater(metrics['accuracy'], 0.9) + class SDCALinearRegressorTest(test.TestCase): @@ -350,6 +392,48 @@ class SDCALinearRegressorTest(test.TestCase): loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss'] self.assertLess(loss, 0.05) + def testMixedFeaturesArbitraryWeightsPartitioned(self): + """Tests SDCALinearRegressor works with a mix of features (partitioned).""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + constant_op.constant([[0.6], [0.8], [0.3]]), + 'sq_footage': + constant_op.constant([[900.0], [700.0], [600.0]]), + 'country': + sparse_tensor.SparseTensor( + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 3], [2, 1]], + dense_shape=[3, 5]), + 'weights': + constant_op.constant([[3.0], [5.0], [7.0]]) + }, constant_op.constant([[1.55], [-1.25], [-3.0]]) + + with self._single_threaded_test_session(): + price = feature_column_lib.real_valued_column('price') + sq_footage_bucket = feature_column_lib.bucketized_column( + feature_column_lib.real_valued_column('sq_footage'), + boundaries=[650.0, 800.0]) + country = feature_column_lib.sparse_column_with_hash_bucket( + 'country', hash_bucket_size=5) + sq_footage_country = feature_column_lib.crossed_column( + [sq_footage_bucket, country], hash_bucket_size=10) + regressor = sdca_estimator.SDCALinearRegressor( + example_id_column='example_id', + feature_columns=[ + price, sq_footage_bucket, country, sq_footage_country + ], + l2_regularization=1.0, + weight_column_name='weights', + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)) + regressor.fit(input_fn=input_fn, steps=20) + loss = 
regressor.evaluate(input_fn=input_fn, steps=1)['loss'] + self.assertLess(loss, 0.05) + def testSdcaOptimizerSparseFeaturesWithL1Reg(self): """SDCALinearRegressor works with sparse features and L1 regularization.""" diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py index 12039ecc6f357af07e0c2a08e17d46396f3ad386..9872c6f97c879d8994b6c26e65df33e368a0603e 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py @@ -64,7 +64,8 @@ class SDCAOptimizer(object): of workers running the train steps. It defaults to 1 (single machine). `num_table_shards` defines the number of shards for the internal state table, typically set to match the number of parameter servers for large - data sets. + data sets. You can also specify a `partitioner` object to partition the primal + weights during training (`div` partitioning strategy will be used). 
""" def __init__(self, @@ -73,13 +74,15 @@ class SDCAOptimizer(object): num_table_shards=None, symmetric_l1_regularization=0.0, symmetric_l2_regularization=1.0, - adaptive=True): + adaptive=True, + partitioner=None): self._example_id_column = example_id_column self._num_loss_partitions = num_loss_partitions self._num_table_shards = num_table_shards self._symmetric_l1_regularization = symmetric_l1_regularization self._symmetric_l2_regularization = symmetric_l2_regularization self._adaptive = adaptive + self._partitioner = partitioner def get_name(self): return 'SDCAOptimizer' @@ -108,6 +111,10 @@ class SDCAOptimizer(object): def adaptive(self): return self._adaptive + @property + def partitioner(self): + return self._partitioner + def get_train_step(self, columns_to_variables, weight_column_name, loss_type, features, targets, global_step): """Returns the training operation of an SdcaModel optimizer.""" @@ -175,10 +182,12 @@ class SDCAOptimizer(object): sparse_feature_column = _dense_tensor_to_sparse_feature_column( dense_bucket_tensor) sparse_feature_with_values.append(sparse_feature_column) - # For bucketized columns, the variables list contains exactly one - # element. - sparse_feature_with_values_weights.append( - columns_to_variables[column][0]) + # If a partitioner was used during variable creation, we will have a + # list of Variables here larger than 1. + vars_to_append = columns_to_variables[column][0] + if len(columns_to_variables[column]) > 1: + vars_to_append = columns_to_variables[column] + sparse_feature_with_values_weights.append(vars_to_append) elif isinstance( column, ( @@ -226,8 +235,12 @@ class SDCAOptimizer(object): array_ops.shape(ids)[0]), [-1]) sparse_feature_with_values.append( SparseFeatureColumn(example_ids_filtered, reproject_ids, weights)) - sparse_feature_with_values_weights.append( - columns_to_variables[column][0]) + # If a partitioner was used during variable creation, we will have a + # list of Variables here larger than 1. 
+ vars_to_append = columns_to_variables[column][0] + if len(columns_to_variables[column]) > 1: + vars_to_append = columns_to_variables[column] + sparse_feature_with_values_weights.append(vars_to_append) else: raise ValueError('SDCAOptimizer does not support column type %s.' % type(column).__name__) diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index f0e7005d66cd7921a15837f44343907a97d43812..9c804d27854b8004d34c65691b48ca2b0d3bbf7c 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -90,6 +90,16 @@ cc_library( deps = [":context"], ) +cc_library( + name = "kernel_api", + hdrs = [ + "builtin_op_data.h", + "builtin_ops.h", + "context.h", + "context_util.h", + ], +) + exports_files(["builtin_ops.h"]) cc_library( @@ -112,6 +122,7 @@ cc_library( "interpreter.cc", "model.cc", "nnapi_delegate.cc", + "op_resolver.cc", "optional_debug_tools.cc", ], hdrs = [ @@ -122,6 +133,7 @@ cc_library( "interpreter.h", "model.h", "nnapi_delegate.h", + "op_resolver.h", "optional_debug_tools.h", ], copts = tflite_copts(), @@ -224,6 +236,18 @@ cc_test( ], ) +# Test OpResolver. +cc_test( + name = "op_resolver_test", + size = "small", + srcs = ["op_resolver_test.cc"], + deps = [ + ":framework", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + # Test the C extension API code. 
cc_test( name = "context_test", diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile index e4f86e258afe3df9ba149c82066b6d145f332488..cc8a8035d1dadeec98886ba1dae4cdf403f26de4 100644 --- a/tensorflow/contrib/lite/Makefile +++ b/tensorflow/contrib/lite/Makefile @@ -29,7 +29,7 @@ GENDIR := $(MAKEFILE_DIR)/gen/obj/ CXX := $(CC_PREFIX)gcc CXXFLAGS := --std=c++11 -O3 -DNDEBUG CC := $(CC_PREFIX)gcc -CFLAGS := -O3 -DNDEBUG +CCFLAGS := -O3 -DNDEBUG LDOPTS := LDOPTS += -L/usr/local/lib ARFLAGS := -r diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 85216776823eab2ab3ac2a3bc666f21e312acc6c..612813caee880f3f7291ee9850f7d8f842d598a6 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -1,4 +1,8 @@ """Generate Flatbuffer binary from json.""" +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) def tflite_copts(): """Defines compile time flags.""" @@ -185,32 +189,109 @@ def json_to_tflite(name, src, out): tools = [flatc], ) -def gen_zipped_test_files(name, files): +# This is the master list of generated examples that will be made into tests. A +# function called make_XXX_tests() must also appear in generate_examples.py. +# Disable a test by commenting it out. If you do, add a link to a bug or issue. 
+def generated_test_models(): + return [ + "add", + "arg_max", + "avg_pool", + "batch_to_space_nd", + "concat", + "constant", + "control_dep", + "conv", + "depthwiseconv", + "div", + "equal", + "exp", + "expand_dims", + "floor", + "fully_connected", + "fused_batch_norm", + "gather", + "global_batch_norm", + "greater", + "greater_equal", + "l2norm", + "l2_pool", + "less", + "less_equal", + "local_response_norm", + "log_softmax", + "log", + "lstm", + "max_pool", + "maximum", + "mean", + "minimum", + "mul", + "neg", + "not_equal", + "pad", + "padv2", + # "prelu", + "relu", + "relu1", + "relu6", + "reshape", + "resize_bilinear", + "sigmoid", + "sin", + "slice", + "softmax", + "space_to_batch_nd", + "space_to_depth", + "sparse_to_dense", + "split", + "squeeze", + "strided_slice", + "strided_slice_1d_exhaustive", + "sub", + "tile", + "topk", + "transpose", + "transpose_conv", + "where", + ] + +def gen_zip_test(name, test_name, **kwargs): + """Generate a zipped-example test and its dependent zip files. + + Args: + name: Resulting cc_test target name + test_name: Test targets this model. Comes from the list above. + **kwargs: tf_cc_test kwargs. + """ + gen_zipped_test_file( + name = "zip_%s" % test_name, + file = "%s.zip" % test_name, + ) + tf_cc_test(name, **kwargs) + +def gen_zipped_test_file(name, file): """Generate a zip file of tests by using :generate_examples. Args: - name: Name of output. We will produce "`name`_files" as a target. - files: A list of zip file basenames. + name: Name of output. We will produce "`file`.files" as a target. + file: The name of one of the generated_examples targets, e.g. 
"transpose" """ toco = "//tensorflow/contrib/lite/toco:toco" - out_files = [] - for f in files: - out_file = name + "/" + f - out_files.append(out_file) - native.genrule( - name = name + "_" + f + ".files", - cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco - + " --zip_to_output " + f + " $(@D)"), - outs = [out_file], - tools = [ - ":generate_examples", - toco, - ], - ) + native.genrule( + name = file + ".files", + cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco + + " --zip_to_output " + file + " $(@D)"), + outs = [file], + tools = [ + ":generate_examples", + toco, + ], + ) native.filegroup( name = name, - srcs = out_files, + srcs = [file], ) def gen_selected_ops(name, model): diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 35cf43dd32b484f64e0db7d24a06691bdd0c830a..c1cc4476fbd45fa6b3f5b3a1ed2cba39cc2ad54b 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -148,10 +148,20 @@ typedef struct { float beta; } TfLiteLocalResponseNormParams; +typedef enum { + kTfLiteLSTMFullKernel = 0, + kTfLiteLSTMBasicKernel +} TfLiteLSTMKernelType; + typedef struct { + // Parameters for LSTM version 1. TfLiteFusedActivation activation; float cell_clip; float proj_clip; + + // Parameters for LSTM version 2. + // kTfLiteLSTMBasicKernel is only supported in version 2 or above. 
+ TfLiteLSTMKernelType kernel_type; } TfLiteLSTMParams; typedef struct { @@ -230,6 +240,16 @@ typedef struct { TfLiteType output_type; } TfLiteArgMaxParams; +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; +} TfLiteTransposeConvParams; + +typedef struct { + bool validate_indices; +} TfLiteSparseToDenseParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index a038acf2848b21a225cbe9933cc8ae1f09739cee..aef9a92883f18dabfc36058507d739856c3c2af7 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ // DO NOT EDIT MANUALLY: This file is automatically generated by -// `schema_builtin_ops_header_generator.py`. +// `schema/builtin_ops_header/generator.cc`. #ifdef __cplusplus extern "C" { @@ -90,10 +90,18 @@ typedef enum { kTfLiteBuiltinGreaterEqual = 62, kTfLiteBuiltinLessEqual = 63, kTfLiteBuiltinSelect = 64, + kTfLiteBuiltinSlice = 65, + kTfLiteBuiltinSin = 66, + kTfLiteBuiltinTransposeConv = 67, + kTfLiteBuiltinSparseToDense = 68, + kTfLiteBuiltinTile = 69, + kTfLiteBuiltinExpandDims = 70, + kTfLiteBuiltinEqual = 71, + kTfLiteBuiltinNotEqual = 72, + kTfLiteBuiltinLog = 73, } TfLiteBuiltinOperator; #ifdef __cplusplus } // extern "C" #endif // __cplusplus #endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ -} diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index 12841d233cc1d3c5e1219fc505b1975d2a7fa3e3..4eb66cc225eb04923be9aaa445a335ad822c8a6f 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -370,13 +370,21 @@ typedef struct _TfLiteRegistration { // Builtin codes. If this kernel refers to a builtin this is the code // of the builtin. This is so we can do marshaling to other frameworks like - // NN API. 
Note, it is the responsibility of the registration binder to - // set this properly. + // NN API. + // Note: It is the responsibility of the registration binder to set this + // properly. int32_t builtin_code; // Custom op name. If the op is a builtin, this will be null. + // Note: It is the responsibility of the registration binder to set this + // properly. // WARNING: This is an experimental interface that is subject to change. const char* custom_name; + + // The version of the op. + // Note: It is the responsibility of the registration binder to set this + // properly. + int version; } TfLiteRegistration; // WARNING: This is an experimental interface that is subject to change. diff --git a/tensorflow/contrib/lite/context_util.h b/tensorflow/contrib/lite/context_util.h new file mode 100644 index 0000000000000000000000000000000000000000..abe802e34214caf4d5063da827b3aca4a82aa56d --- /dev/null +++ b/tensorflow/contrib/lite/context_util.h @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This provides a few C++ helpers that are useful for manipulating C structures +// in C++. 
+#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// Provide a range iterable wrapper for TfLiteIntArray* (C lists that TfLite +// C api uses. Can't use the google array_view, since we can't depend on even +// absl for embedded device reasons. +class TfLiteIntArrayView { + public: + // Construct a view of a TfLiteIntArray*. Note, `int_array` should be non-null + // and this view does not take ownership of it. + explicit TfLiteIntArrayView(const TfLiteIntArray* int_array) + : int_array_(int_array) {} + + TfLiteIntArrayView(const TfLiteIntArrayView&) = default; + TfLiteIntArrayView& operator=(const TfLiteIntArrayView& rhs) = default; + + typedef const int* const_iterator; + const_iterator begin() const { return int_array_->data; } + const_iterator end() const { return &int_array_->data[int_array_->size]; } + size_t size() const { return end() - begin(); } + + private: + const TfLiteIntArray* int_array_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ diff --git a/tensorflow/contrib/lite/delegates/nnapi/BUILD b/tensorflow/contrib/lite/delegates/nnapi/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..35a8f6ca4166e373ea1a0af5d4a013327b30d2b6 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/BUILD @@ -0,0 +1,31 @@ +package(default_visibility = [ + "//visibility:public", +]) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "nnapi_delegate", + srcs = ["nnapi_delegate.cc"], + hdrs = ["nnapi_delegate.h"], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/contrib/lite/kernels:kernel_util", + "//tensorflow/contrib/lite/nnapi:nnapi_lib", + ], +) + +tf_cc_test( + name = "nnapi_delegate_test", + size = "small", + srcs = ["nnapi_delegate_test.cc"], + deps = [ + 
":nnapi_delegate", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc new file mode 100644 index 0000000000000000000000000000000000000000..0731d14419d2dec2ea5efa48ef5d4b7728af635f --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc @@ -0,0 +1,464 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/builtin_ops.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" + +namespace tflite { +namespace { + +// TODO(b/80621585): Consider printing error string, but don't for now to +// minimize binary size. 
+#define CHECK_NN(context, code) \ + if (code != ANEURALNETWORKS_NO_ERROR) { \ + context->ReportError(context, "NN API returned error (%d).\n", code); \ + return kTfLiteError; \ + } + +// RAII NN API Model Destructor for use with std::unique_ptr +struct NNFreeModel { + void operator()(ANeuralNetworksModel* model) { + ANeuralNetworksModel_free(model); + } +}; +// RAII NN API Compilation Destructor for use with std::unique_ptr +struct NNFreeCompilation { + void operator()(ANeuralNetworksCompilation* model) { + ANeuralNetworksCompilation_free(model); + } +}; + +// Track tensor indices to NN API tensor indices mapping. +class OperandMapping { + public: + // Given a TFLite index return the ANN index. If it doesn't exist + // return -1. + int lite_index_to_ann(int index) const { + if (index < lite_tensor_to_ann_tensor_.size()) + return lite_tensor_to_ann_tensor_[index]; + else + return -1; + } + + // NN API uses non tensor operands instead of structs. This creates one + // and returns the index. It uses a std::vector and resizes it as needed + // keeping -1 to unmapped values. Intermediate tensors likely will not + // be mapped. + int add_new_non_tensor_operand() { return next_ann_tensor_index_++; } + + // Add a new mapping from `tflite_index` and return the NN API tensor index. + int add_new_ann_tensor_index(int tflite_index) { + if (tflite_index >= lite_tensor_to_ann_tensor_.size()) { + lite_tensor_to_ann_tensor_.resize(tflite_index + 1, -1); + } + int new_tensor_index = next_ann_tensor_index_++; + lite_tensor_to_ann_tensor_[tflite_index] = new_tensor_index; + return new_tensor_index; + } + + private: + // Next index of ann tensor + int next_ann_tensor_index_ = 0; + + // Mapping from lite index. Use a std::vector for speed and code size + // rather than a map. + std::vector<int> lite_tensor_to_ann_tensor_; +}; + +// Abstract builder for building an op in the NN API graph. This handles +// the disparity between TFLite and NN API operand types.
NN API has singular +// operands for both tensors and parameters, and TFLite separates the two. +class NNAPIOpBuilder { + public: + NNAPIOpBuilder(TfLiteContext* context, OperandMapping* tensor_mapping, + ANeuralNetworksModel* nn_model) + : context_(context), + operand_mapping_(tensor_mapping), + nn_model_(nn_model) {} + + TfLiteStatus AddScalarInt32Operand(int value) { + ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_INT32}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + int ann_operand = operand_mapping_->add_new_non_tensor_operand(); + CHECK_NN(context_, ANeuralNetworksModel_setOperandValue( + nn_model_, ann_operand, &value, sizeof(int32_t))); + augmented_inputs_.push_back(ann_operand); + return kTfLiteOk; + } + + TfLiteStatus AddTensorInput(int tensor_index) { + int ann_index; + TF_LITE_ENSURE_STATUS(AddTensor(tensor_index, &ann_index)); + augmented_inputs_.push_back(ann_index); + return kTfLiteOk; + } + + TfLiteStatus AddTensorOutput(int tensor_index) { + int ann_index; + TF_LITE_ENSURE_STATUS(AddTensor(tensor_index, &ann_index)); + augmented_outputs_.push_back(ann_index); + return kTfLiteOk; + } + + // Adds a new NN API tensor that shadows the TF Lite tensor `tensor_index`. + // This returns the NN API tensor index corresponding to the created tensor. + // If another caller previously created a NN API tensor for `tensor_index` + // then the existing one is returned. + TfLiteStatus AddTensor(int tensor_index, int* ann_tensor_index_out) { + int ann_tensor_index = operand_mapping_->lite_index_to_ann(tensor_index); + if (ann_tensor_index != -1) { + *ann_tensor_index_out = ann_tensor_index; + return kTfLiteOk; + } + // Allocate a new tensor index + ann_tensor_index = operand_mapping_->add_new_ann_tensor_index(tensor_index); + + // Parameters needed for new type. 
+ int32_t nn_type = 0; + float scale = 0.0f; + int32_t zeroPoint = 0; + TfLiteTensor* tensor = &context_->tensors[tensor_index]; + switch (tensor->type) { + case kTfLiteNoType: + // Tensors added during initialization of Ops don't have a type yet and + // should not be registered with the NNAPI. + *ann_tensor_index_out = -1; + return kTfLiteOk; + case kTfLiteFloat32: + nn_type = ANEURALNETWORKS_TENSOR_FLOAT32; + scale = 0.f; + break; + case kTfLiteUInt8: + nn_type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM; + scale = tensor->params.scale; + zeroPoint = tensor->params.zero_point; + break; + case kTfLiteInt32: + nn_type = ANEURALNETWORKS_TENSOR_INT32; + scale = 0.f; + zeroPoint = 0; + break; + default: + context_->ReportError(context_, "Logic error in NN API Delegate.\n"); + return kTfLiteError; + } + + ANeuralNetworksOperandType operand_type{ + nn_type, static_cast<uint32_t>(tensor->dims->size), + reinterpret_cast<uint32_t*>(tensor->dims->data), scale, zeroPoint}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + + if (tensor->allocation_type == kTfLiteMmapRo) { + // TODO(b/80630405): Use NNAPIAllocation. + CHECK_NN(context_, ANeuralNetworksModel_setOperandValue( + nn_model_, ann_tensor_index, tensor->data.raw, + tensor->bytes)); + } + + *ann_tensor_index_out = ann_tensor_index; + return kTfLiteOk; + } + + // Finish emitting the op (of type `type`) into the NN API. + TfLiteStatus FinalizeAddOperation(ANeuralNetworksOperationType type) { + // Actually add a NN API operation + CHECK_NN(context_, ANeuralNetworksModel_addOperation( + nn_model_, type, + static_cast<uint32_t>(augmented_inputs_.size()), + augmented_inputs_.data(), + static_cast<uint32_t>(augmented_outputs_.size()), + augmented_outputs_.data())); + augmented_inputs_.clear(); + augmented_outputs_.clear(); + return kTfLiteOk; + } + + private: + // TfLiteContext for error handling. Must be named context for macros to + // work.
+ TfLiteContext* context_; + + // Tracks relationship between indices + OperandMapping* operand_mapping_; + + // The model + ANeuralNetworksModel* nn_model_; + + // Inputs and outputs for the current op. These are augmented in the sense + // that NN API uses operands for all arguments, not just tensors, unlike + // TensorFlow lite. + std::vector<uint32_t> augmented_inputs_; + std::vector<uint32_t> augmented_outputs_; +}; + +// The kernel that represents the subgraph of TF Lite being run on NN API. +class NNAPIDelegateKernel { + public: + NNAPIDelegateKernel() = default; + + typedef ANeuralNetworksOperationType (*MappingFn)(TfLiteContext*, + NNAPIOpBuilder* builder, + TfLiteNode* node); + + // Return a function that knows how to translate a node into its operands + // when called. You can use this function to see if a node is supported + // (i.e. that MappingFn is not nullptr). + MappingFn Map(TfLiteContext* context, int builtin_code, TfLiteNode* node) { + switch (builtin_code) { + case kTfLiteBuiltinAdd: + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast<TfLiteAddParams*>(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_ADD; + }; + break; + case kTfLiteBuiltinAveragePool2d: + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast<TfLitePoolParams*>(node->builtin_data); + builder->AddScalarInt32Operand(builtin->padding); + builder->AddScalarInt32Operand(builtin->stride_width); + builder->AddScalarInt32Operand(builtin->stride_height); + builder->AddScalarInt32Operand(builtin->filter_width); + builder->AddScalarInt32Operand(builtin->filter_height); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_AVERAGE_POOL_2D; + }; + break; + default: + return nullptr; + } + } + + // Initialize the kernel (a NN model).
+ TfLiteStatus Init(TfLiteContext* context, + const TfLiteDelegateParams* params) { + for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) { + nodes_.push_back(node_index); + } + + if (!nn_model_) { + ANeuralNetworksModel* model; + CHECK_NN(context, ANeuralNetworksModel_create(&model)); + nn_model_.reset(model); + + TF_LITE_ENSURE_STATUS( + BuildGraph(context, params->input_tensors, params->output_tensors)); + } + + if (!nn_compilation_) { + ANeuralNetworksCompilation* compilation; + CHECK_NN(context, ANeuralNetworksCompilation_create(nn_model_.get(), + &compilation)); + CHECK_NN(context, ANeuralNetworksCompilation_finish(compilation)); + nn_compilation_.reset(compilation); + } + return kTfLiteOk; + } + + TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) { + ANeuralNetworksExecution* execution = nullptr; + CHECK_NN(context, ANeuralNetworksExecution_create(nn_compilation_.get(), + &execution)); + + // Set the input tensor buffers. Note: we access tflite tensors using + // absolute indices but NN api indices inputs by relative indices. + int relative_input_index = 0; + for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) { + TfLiteTensor* tensor = &context->tensors[absolute_input_index]; + CHECK_NN(context, ANeuralNetworksExecution_setInput( + execution, relative_input_index, nullptr, + tensor->data.raw, tensor->bytes)); + relative_input_index++; + } + + // Set the output tensor buffers. + int relative_output_index = 0; + for (auto output_index : TfLiteIntArrayView(node->outputs)) { + TfLiteTensor* tensor = &context->tensors[output_index]; + CHECK_NN(context, ANeuralNetworksExecution_setOutput( + execution, relative_output_index, nullptr, + tensor->data.raw, tensor->bytes)); + relative_output_index++; + } + // Invoke ANN in blocking fashion. 
+ ANeuralNetworksEvent* event = nullptr; + CHECK_NN(context, ANeuralNetworksExecution_startCompute(execution, &event)); + CHECK_NN(context, ANeuralNetworksEvent_wait(event)); + ANeuralNetworksEvent_free(event); + ANeuralNetworksExecution_free(execution); + + return kTfLiteOk; + } + + private: + // ANN API state. + std::unique_ptr nn_model_; + std::unique_ptr + nn_compilation_; + // Node indices that this delegate is responsible for. Indices here + // indexes into the nodes array in the TfLiteContext. + std::vector nodes_; + // Track indices we use + OperandMapping operand_mapping_; + + TfLiteStatus AddOpsAndTensors(TfLiteContext* context) { + // The operand builder allows creating a single op. We create it at this + // reduced power position rather than in the for loop to avoid reallocating + // the vectors. + NNAPIOpBuilder builder(context, &operand_mapping_, nn_model_.get()); + // Add Tensors + // allocate outside to avoid realloc + for (auto node_index : nodes_) { + // Obtain the op and registration. + TfLiteNode* node; + TfLiteRegistration* reg; + context->GetNodeAndRegistration(context, node_index, &node, &reg); + // Map inputs to NN API tensor indices. + for (auto input_index : TfLiteIntArrayView(node->inputs)) { + TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index)); + } + // Get op type and operands + int nn_op_type = + Map(context, reg->builtin_code, node)(context, &builder, node); + // Map outputs to NN API tensor indices. + for (auto output_index : TfLiteIntArrayView(node->outputs)) { + TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index)); + } + + builder.FinalizeAddOperation(nn_op_type); + } + return kTfLiteOk; + } + + TfLiteStatus BuildGraph(TfLiteContext* context, + const TfLiteIntArray* input_tensors, + const TfLiteIntArray* output_tensors) { + // Build the ops and tensors.
+ TF_LITE_ENSURE_STATUS(AddOpsAndTensors(context)); + // Map input and output tensor indices to ANN + std::vector inputs; + inputs.reserve(input_tensors->size); + std::vector outputs; + outputs.reserve(output_tensors->size); + // Make the TensorFlow lite inputs and outputs to ann_indices. + for (int i : TfLiteIntArrayView(input_tensors)) + inputs.push_back(operand_mapping_.lite_index_to_ann(i)); + for (int i : TfLiteIntArrayView(output_tensors)) + outputs.push_back(operand_mapping_.lite_index_to_ann(i)); + // Tell ANN to declare inputs/outputs + CHECK_NN(context, ANeuralNetworksModel_identifyInputsAndOutputs( + nn_model_.get(), inputs.size(), inputs.data(), + outputs.size(), outputs.data())); + // Finalize the model + CHECK_NN(context, ANeuralNetworksModel_finish(nn_model_.get())); + + return kTfLiteOk; + } +}; + +} // namespace + +// Return a NN API Delegate struct that can check for support of ops. +TfLiteDelegate* NnApiDelegate() { + static TfLiteDelegate delegate = { + .data_ = nullptr, + .Prepare = [](TfLiteContext* context, + TfLiteDelegate* delegate) -> TfLiteStatus { + // Do not check nodes_ if NN API is unavailable. + if (!NNAPIExists()) return kTfLiteOk; + + std::vector supported_nodes(1); + // We don't care about all nodes_, we only care about ones in the + // current plan. + TfLiteIntArray* plan; + TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); + int total_supported_nodes = 0; + // Check for every node if it is supported + // TODO(b/80625235): Fix this to do more careful checking of versioning. + for (int node_index : TfLiteIntArrayView(plan)) { + TfLiteNode* node; + TfLiteRegistration* registration; + TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( + context, node_index, &node, &registration)); + NNAPIDelegateKernel dummy_kernel; + if (dummy_kernel.Map(context, registration->builtin_code, node)) { + supported_nodes.push_back(node_index); + } + total_supported_nodes += 1; + } + // Put the size at the beginning of the array.
+ supported_nodes[0] = supported_nodes.size() - 1; + + // NN API Delegate Registration (the pseudo kernel that will invoke NN + // API subgraphs) + static const TfLiteRegistration nnapi_delegate_kernel = { + .init = [](TfLiteContext* context, const char* buffer, + size_t length) -> void* { + const TfLiteDelegateParams* params = + reinterpret_cast(buffer); + NNAPIDelegateKernel* kernel_state = new NNAPIDelegateKernel; + kernel_state->Init(context, params); + return kernel_state; + }, + + .free = [](TfLiteContext* context, void* buffer) -> void { + delete reinterpret_cast(buffer); + }, + + .prepare = [](TfLiteContext* context, + TfLiteNode* node) -> TfLiteStatus { + // Since the underlying resize happened ahead of delegation + // worked. This does nothing. + return kTfLiteOk; + }, + + .invoke = [](TfLiteContext* context, + TfLiteNode* node) -> TfLiteStatus { + NNAPIDelegateKernel* state = + reinterpret_cast(node->user_data); + return state->Invoke(context, node); + }, + + .builtin_code = kTfLiteBuiltinDelegate, + }; + + // Request TFLite to partition the graph and make kernels + // for each independent subgraph a new nnapi_delegate_kernel. + context->ReplaceSubgraphsWithDelegateKernels( + context, nnapi_delegate_kernel, + reinterpret_cast(supported_nodes.data()), + delegate); + return kTfLiteOk; + }}; + + return &delegate; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h new file mode 100644 index 0000000000000000000000000000000000000000..44cca2fd285370d700525f98ba33c861fb97be1e --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// Return a delegate that can be used to use the NN API. +// e.g. +// NnApiDelegate* delegate = NnApiDelegate(); +// interpreter->ModifyGraphWithDelegate(&delegate); +// NnApiDelegate() returns a singleton, so you should not free this +// pointer or worry about its lifetime. +TfLiteDelegate* NnApiDelegate(); +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ff2e721423f07889f36746a2889afcc3369f28fc --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -0,0 +1,82 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h" +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class FloatAddOpModel : public SingleOpModel { + public: + FloatAddOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, + ActivationFunctionType activation_type) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions, + CreateAddOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input1_; + int input2_; + int output_; +}; + +// Do a test with the NN API using no activation. +TEST(NNAPIDelegate, AddWithNoActivation) { + FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3})); +} + +// Do a test with the NN api with relu. 
+TEST(NNAPIDelegate, AddWithRelu) { + FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0.0, 0.4, 1.0, 1.3})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh index 436c3e1d4cad5e6ee355d7e9cf8ee7da1a8385ce..840015a7fad173dbd2ea353786871dd4e89abb98 100755 --- a/tensorflow/contrib/lite/download_dependencies.sh +++ b/tensorflow/contrib/lite/download_dependencies.sh @@ -30,9 +30,7 @@ if [ ! -f $BZL_FILE_PATH ]; then fi EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" -# TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' once -# the archive has been propagated in mirror.bazel.build. 
-GEMMLOWP_URL="$(grep -o 'https://github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" +GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip" diff --git a/tensorflow/contrib/lite/examples/label_image/BUILD b/tensorflow/contrib/lite/examples/label_image/BUILD index 9322e186a280e932a2441ab16ac8579d9ab67ee2..c61445114ecc6dfbe4f2b6ab666b28a8aa746be3 100644 --- a/tensorflow/contrib/lite/examples/label_image/BUILD +++ b/tensorflow/contrib/lite/examples/label_image/BUILD @@ -53,19 +53,18 @@ cc_library( ], ) -# TODO(ahentz): Test disabled as it has a memory leek from read_bmp -# cc_test( -# name = "label_image_test", -# srcs = [ -# "get_top_n.h", -# "get_top_n_impl.h", -# "label_image_test.cc", -# ], -# data = [ -# "testdata/grace_hopper.bmp", -# ], -# deps = [ -# ":bitmap_helpers", -# "//testing/base/public:gunit", -# ], -# ) +cc_test( + name = "label_image_test", + srcs = [ + "get_top_n.h", + "get_top_n_impl.h", + "label_image_test.cc", + ], + data = [ + "testdata/grace_hopper.bmp", + ], + deps = [ + ":bitmap_helpers", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc index 0b38cd38c83927c65d251b9356301b6bef7521f2..2735d1f5ea4e2a104f71a3a6f874d9acb2f48142 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc @@ -28,8 +28,9 @@ limitations under the License. 
namespace tflite { namespace label_image { -uint8_t* decode_bmp(const uint8_t* input, int row_size, uint8_t* const output, - int width, int height, int channels, bool top_down) { +std::vector decode_bmp(const uint8_t* input, int row_size, int width, + int height, int channels, bool top_down) { + std::vector output(height * width * channels); for (int i = 0; i < height; i++) { int src_pos; int dst_pos; @@ -66,12 +67,11 @@ uint8_t* decode_bmp(const uint8_t* input, int row_size, uint8_t* const output, } } } - return output; } -uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, - int* channels, Settings* s) { +std::vector read_bmp(const std::string& input_bmp_name, int* width, + int* height, int* channels, Settings* s) { int begin, end; std::ifstream file(input_bmp_name, std::ios::in | std::ios::binary); @@ -87,14 +87,15 @@ uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, if (s->verbose) LOG(INFO) << "len: " << len << "\n"; - const uint8_t* img_bytes = new uint8_t[len]; + std::vector img_bytes(len); file.seekg(0, std::ios::beg); - file.read((char*)img_bytes, len); + file.read(reinterpret_cast(img_bytes.data()), len); const int32_t header_size = - *(reinterpret_cast(img_bytes + 10)); - *width = *(reinterpret_cast(img_bytes + 18)); - *height = *(reinterpret_cast(img_bytes + 22)); - const int32_t bpp = *(reinterpret_cast(img_bytes + 28)); + *(reinterpret_cast(img_bytes.data() + 10)); + *width = *(reinterpret_cast(img_bytes.data() + 18)); + *height = *(reinterpret_cast(img_bytes.data() + 22)); + const int32_t bpp = + *(reinterpret_cast(img_bytes.data() + 28)); *channels = bpp / 8; if (s->verbose) @@ -110,10 +111,9 @@ uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, bool top_down = (*height < 0); // Decode image, allocating tensor once the image size is known - uint8_t* output = new uint8_t[abs(*height) * *width * *channels]; const uint8_t* bmp_pixels = &img_bytes[header_size]; - return 
decode_bmp(bmp_pixels, row_size, output, *width, abs(*height), - *channels, top_down); + return decode_bmp(bmp_pixels, row_size, *width, abs(*height), *channels, + top_down); } } // namespace label_image diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h index 97343dde6b31694e5b2de20b35a7083fb8fe4a0e..5fc75b1f7274c14d49e4a26d6ce4902c037afa6b 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h @@ -22,8 +22,8 @@ limitations under the License. namespace tflite { namespace label_image { -uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, - int* channels, Settings* s); +std::vector read_bmp(const std::string& input_bmp_name, int* width, + int* height, int* channels, Settings* s); template void resize(T* out, uint8_t* in, int image_height, int image_width, diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h index 2a64c1de725b601e9b6e9325d9faacb37df0e626..e36218e4f12057a362af47c48454f7930fc495f2 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h @@ -62,8 +62,8 @@ void resize(T* out, uint8_t* in, int image_height, int image_width, {1, wanted_height, wanted_width, wanted_channels}, quant); ops::builtin::BuiltinOpResolver resolver; - TfLiteRegistration* resize_op = - resolver.FindOp(BuiltinOperator_RESIZE_BILINEAR); + const TfLiteRegistration* resize_op = + resolver.FindOp(BuiltinOperator_RESIZE_BILINEAR, 1); auto* params = reinterpret_cast( malloc(sizeof(TfLiteResizeBilinearParams))); params->align_corners = false; diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc index 
456c5c6dc782f4e21a5062e353635117a39cacb9..86d7d1cc4a625243791d5e7d5b746526a58efb6d 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc @@ -77,14 +77,13 @@ void PrintProfilingInfo(const profiling::ProfileEvent* e, uint32_t op_index, // time (ms) , Node xxx, OpCode xxx, symblic name // 5.352, Node 5, OpCode 4, DEPTHWISE_CONV_2D - LOG(INFO) << std::fixed << std::setw(10) << std::setprecision(3) << (e->end_timestamp_us - e->begin_timestamp_us) / 1000.0 << ", Node " << std::setw(3) << std::setprecision(3) << op_index << ", OpCode " << std::setw(3) << std::setprecision(3) << registration.builtin_code << ", " << EnumNameBuiltinOperator( - (BuiltinOperator)registration.builtin_code) + static_cast(registration.builtin_code)) << "\n"; } @@ -139,8 +138,8 @@ void RunInference(Settings* s) { int image_width = 224; int image_height = 224; int image_channels = 3; - uint8_t* in = read_bmp(s->input_bmp_name, &image_width, &image_height, - &image_channels, s); + std::vector in = read_bmp(s->input_bmp_name, &image_width, + &image_height, &image_channels, s); int input = interpreter->inputs()[0]; if (s->verbose) LOG(INFO) << "input: " << input << "\n"; @@ -169,12 +168,12 @@ void RunInference(Settings* s) { switch (interpreter->tensor(input)->type) { case kTfLiteFloat32: s->input_floating = true; - resize(interpreter->typed_tensor(input), in, image_height, - image_width, image_channels, wanted_height, wanted_width, - wanted_channels, s); + resize(interpreter->typed_tensor(input), in.data(), + image_height, image_width, image_channels, wanted_height, + wanted_width, wanted_channels, s); break; case kTfLiteUInt8: - resize(interpreter->typed_tensor(input), in, + resize(interpreter->typed_tensor(input), in.data(), image_height, image_width, image_channels, wanted_height, wanted_width, wanted_channels, s); break; @@ -190,13 +189,13 @@ void RunInference(Settings* s) { if (s->profiling) 
profiler->StartProfiling(); struct timeval start_time, stop_time; - gettimeofday(&start_time, NULL); + gettimeofday(&start_time, nullptr); for (int i = 0; i < s->loop_count; i++) { if (interpreter->Invoke() != kTfLiteOk) { LOG(FATAL) << "Failed to invoke tflite!\n"; } } - gettimeofday(&stop_time, NULL); + gettimeofday(&stop_time, nullptr); LOG(INFO) << "invoked \n"; LOG(INFO) << "average time: " << (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000) @@ -271,17 +270,17 @@ int Main(int argc, char** argv) { int c; while (1) { static struct option long_options[] = { - {"accelerated", required_argument, 0, 'a'}, - {"count", required_argument, 0, 'c'}, - {"verbose", required_argument, 0, 'v'}, - {"image", required_argument, 0, 'i'}, - {"labels", required_argument, 0, 'l'}, - {"tflite_model", required_argument, 0, 'm'}, - {"profiling", required_argument, 0, 'p'}, - {"threads", required_argument, 0, 't'}, - {"input_mean", required_argument, 0, 'b'}, - {"input_std", required_argument, 0, 's'}, - {0, 0, 0, 0}}; + {"accelerated", required_argument, nullptr, 'a'}, + {"count", required_argument, nullptr, 'c'}, + {"verbose", required_argument, nullptr, 'v'}, + {"image", required_argument, nullptr, 'i'}, + {"labels", required_argument, nullptr, 'l'}, + {"tflite_model", required_argument, nullptr, 'm'}, + {"profiling", required_argument, nullptr, 'p'}, + {"threads", required_argument, nullptr, 't'}, + {"input_mean", required_argument, nullptr, 'b'}, + {"input_std", required_argument, nullptr, 's'}, + {nullptr, 0, nullptr, 0}}; /* getopt_long stores the option index here. 
*/ int option_index = 0; @@ -294,15 +293,14 @@ int Main(int argc, char** argv) { switch (c) { case 'a': - s.accel = strtol( // NOLINT(runtime/deprecated_fn) - optarg, (char**)NULL, 10); + s.accel = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'b': - s.input_mean = strtod(optarg, NULL); + s.input_mean = strtod(optarg, nullptr); break; case 'c': - s.loop_count = strtol( // NOLINT(runtime/deprecated_fn) - optarg, (char**)NULL, 10); + s.loop_count = + strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'i': s.input_bmp_name = optarg; @@ -314,19 +312,19 @@ int Main(int argc, char** argv) { s.model_name = optarg; break; case 'p': - s.profiling = strtol( // NOLINT(runtime/deprecated_fn) - optarg, (char**)NULL, 10); + s.profiling = + strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 's': - s.input_std = strtod(optarg, NULL); + s.input_std = strtod(optarg, nullptr); break; case 't': s.number_of_threads = strtol( // NOLINT(runtime/deprecated_fn) - optarg, (char**)NULL, 10); + optarg, nullptr, 10); break; case 'v': - s.verbose = strtol( // NOLINT(runtime/deprecated_fn) - optarg, (char**)NULL, 10); + s.verbose = + strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'h': case '?': diff --git a/tensorflow/contrib/lite/examples/label_image/label_image_test.cc b/tensorflow/contrib/lite/examples/label_image/label_image_test.cc index ce35483f76e8f40ced79e1ee30774c62d0eba94e..de7de21f7741d3d46cb96e793e8bc4bfb21384fe 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image_test.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image_test.cc @@ -27,20 +27,20 @@ namespace label_image { TEST(LabelImageTest, GraceHopper) { std::string lena_file = - "tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp"; + "tensorflow/contrib/lite/examples/label_image/testdata/" + "grace_hopper.bmp"; int height, width, channels; Settings s; - uint8_t *data; - - data = 
read_bmp(lena_file, &width, &height, &channels, &s); + std::vector input = + read_bmp(lena_file, &width, &height, &channels, &s); ASSERT_EQ(height, 606); ASSERT_EQ(width, 517); ASSERT_EQ(channels, 3); - uint8_t *out = new uint8_t[606 * 517 * 3]; - downsize(out, data, 606, 517, 3, 214, 214, 3, &s); - ASSERT_EQ(out[0], 0x15); - ASSERT_EQ(out[214 * 214 * 3 - 1], 0x12); + std::vector output(606 * 517 * 3); + resize(output.data(), input.data(), 606, 517, 3, 214, 214, 3, &s); + ASSERT_EQ(output[0], 0x15); + ASSERT_EQ(output[214 * 214 * 3 - 1], 0x11); } TEST(LabelImageTest, GetTopN) { diff --git a/tensorflow/contrib/lite/examples/minimal/minimal.cc b/tensorflow/contrib/lite/examples/minimal/minimal.cc index 106e3b027055b67092f653c6bcdc4827b56bdbaa..8b0ace96ccaf06ac1cbdc2ea95ac6e92ef886993 100644 --- a/tensorflow/contrib/lite/examples/minimal/minimal.cc +++ b/tensorflow/contrib/lite/examples/minimal/minimal.cc @@ -38,7 +38,7 @@ using namespace tflite; int main(int argc, char *argv[]) { if(argc != 2) { - fprintf(stderr, "Usage: %s \n"); + fprintf(stderr, "minimal \n"); return 1; } const char* filename = argv[1]; diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md index d7cc854ebac08e79d346df0aca6e1fa56b490156..972e57f73e82961ebc5e341dd7a41bc00acc5d21 100644 --- a/tensorflow/contrib/lite/g3doc/custom_operators.md +++ b/tensorflow/contrib/lite/g3doc/custom_operators.md @@ -39,7 +39,7 @@ TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); int num_dims = NumDimensions(input); @@ -54,7 +54,7 @@ TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { using namespace 
tflite; - TfLiteTensor* input = GetInput(context, node,0); + const TfLiteTensor* input = GetInput(context, node,0); TfLiteTensor* output = GetOutput(context, node,0); float* input_data = input->data.f; diff --git a/tensorflow/contrib/lite/g3doc/ops_versioning.md b/tensorflow/contrib/lite/g3doc/ops_versioning.md new file mode 100644 index 0000000000000000000000000000000000000000..bd2f797e6c5b05f52bec9fc34f1b8011aca70330 --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/ops_versioning.md @@ -0,0 +1,206 @@ +# TensorFlow Lite Ops Versioning + +This document describes TensorFlow Lite's op versioning schema. Op +versioning enables developers to add new functionalities and parameters into +existing ops. In addition, it guarantees the following: + +* Backward compatibility: New TensorFlow Lite implementation should + handle an old model file. +* Forward compatibility: Old TensorFlow Lite implementation should + handle a new model file produced by new version of TOCO, as long as no new + features are used. +* Forward in-compatibility detection: If an old TensorFlow Lite implementation + reads a new model that contains a new version of an op which isn't + supported, it should report the error. + +## Example: Adding Dilation into Convolution + +The remainder of this document explains op versioning in TFLite by showing how +to add dilation parameters to the convolution operation. + +Knowledge of dilation is not required to understand this document. Note that: + +* 2 new integer parameters will be added: `dilation_width_factor` and + `dilation_height_factor`. +* Old convolution kernels that don't support dilation are equivalent to + setting the dilation factors to 1. + +### Change FlatBuffer Schema + +To add new parameters into an op, change the options table in +`lite/schema/schema.fbs`. 
+ +For example, the options table of convolution looks like this: + +``` +table Conv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; +} +``` + +When adding new parameters: + +* Add comments indicating which parameters are supported by which version. +* When the new implementation gets the default values for newly added + parameters, it should work exactly the same as the old implementation. + +The table will be like this after the new parameters are added: + +``` +table Conv2DOptions { + // Parameters supported by version 1: + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; + + // Parameters supported by version 2: + dilation_width_factor:int = 1; + dilation_height_factor:int = 1; +} +``` + +### Change C Structures and Kernel Implementation + +In TensorFlow Lite, the kernel implementation is decoupled from +FlatBuffer definition. The kernels read the parameter from C structures defined +in `lite/builtin_op_data.h`. + +The original convolution parameter is as follows: + +``` +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + TfLiteFusedActivation activation; +} TfLiteConvParams; +``` + +As with the FlatBuffer schema, add comments indicating which parameters are +supported starting from which version. The result is seen below: + +``` +typedef struct { + // Parameters supported by version 1: TfLitePadding padding; int + stride_width; + int stride_height; + TfLiteFusedActivation activation; + + // Parameters supported by version 2: + int dilation_width_factor; + int dilation_height_factor; +} TfLiteConvParams; +``` + +Please also change the kernel implementation to read the newly added parameters +from the C structures. The details are omitted here. + +### Change the FlatBuffer Reading Code + +The logic to read FlatBuffer and produce C structure is in `lite/model.cc`. 
+ +Update the file to handle the new parameters, as shown below: + +``` +case BuiltinOperator_CONV_2D: { + TfLiteConvParams* params = MallocPOD(); + if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) { + params->padding = parse_padding(conv_params->padding()); + params->stride_width = conv_params->stride_w(); + params->stride_height = conv_params->stride_h(); + params->activation = + parse_activation(conv_params->fused_activation_function()); + params->dilation_width_factor = conv_params->dilation_width_factor(); + params->dilation_height_factor = conv_params->dilation_height_factor(); + } + *builtin_data = reinterpret_cast(params); + break; +} +``` + +It's not required to check the op version here. When the new implementation +reads an old model file where dilation factors are missing, it will use 1 as +the default value, and the new kernel will work consistently with the old +kernel. + +### Change Kernel Registration + +The MutableOpResolver (defined in `lite/op_resolver.h`) provides a few functions +to register op kernels. The minimum and maximum version are 1 by default: + +``` +void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration, + int min_version = 1, int max_version = 1); +void AddCustom(const char* name, TfLiteRegistration* registration, + int min_version = 1, int max_version = 1); +``` + +The built-in ops are registered in `lite/kernels/register.cc`. In this example, +we implemented a new op kernel which can handle `Conv2D` version 1 and 2, so we +need to change this line: + +``` +AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D()); +``` + +to: + +``` +AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D(), 1, 2); +``` + +### Change TOCO TFLite exporter + +The last step is to make TOCO populate the minimum version that's required to +execute the op. In this example, it means: + +* Populate version=1 when dilation factors are all 1. +* Populate version=2 otherwise. 
+ +To do this, you need to override the `GetVersion` function for the operator class in +`lite/toco/tflite/operator.cc`. + +For ops with only one version, the `GetVersion` function is defined as: + +``` +int GetVersion(const Operator& op) const override { return 1; } +``` + +When supporting multiple versions, check the parameters and determine the +version for the op, as shown in the following example: + +``` +int GetVersion(const Operator& op) const override { + const auto& conv_op = static_cast(op); + if (conv_op.dilation_width_factor != 1 || + conv_op.dilation_height_factor != 1) { + return 2; + } + return 1; +} +``` + +### Delegation Implementation + +TensorFlow Lite provides a delegation API which enables delegating ops to +hardware backends. In the delegate's `Prepare` function, check that the version +of every node is supported. + +``` +const int kMinVersion = 1; +TfLiteNode* node; +TfLiteRegistration* registration; +context->GetNodeAndRegistration(context, node_index, &node, &registration); + +if (registration->version > kMinVersion) { + // Reject the node if the version isn't supported. +} +``` + +This is required even if the delegation only supports version 1 ops, so the +delegation can detect incompatibility when getting a higher version op.
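The version check described in the Delegation Implementation section could be sketched roughly as follows. This is a minimal, self-contained illustration and not the TensorFlow Lite API: `FakeRegistration`, `kMaxSupportedVersion`, and `SelectSupportedNodes` are hypothetical names invented here, standing in for `TfLiteRegistration` and the per-node loop a real delegate would run inside `Prepare`.

```cpp
#include <vector>

// Hypothetical stand-in for TfLiteRegistration; only the fields that the
// version check needs are modelled here.
struct FakeRegistration {
  int builtin_code;
  int version;
};

// Highest op version this illustrative delegate claims to support (assumed).
constexpr int kMaxSupportedVersion = 1;

// Returns the indices of the nodes whose op version the delegate can handle.
// Nodes with a higher version are skipped, so they stay on the default path.
std::vector<int> SelectSupportedNodes(
    const std::vector<FakeRegistration>& registrations) {
  std::vector<int> supported;
  for (int i = 0; i < static_cast<int>(registrations.size()); ++i) {
    if (registrations[i].version <= kMaxSupportedVersion) {
      supported.push_back(i);
    }
  }
  return supported;
}
```

In a real delegate, the resulting index list would presumably feed the call that asks the interpreter to replace supported subgraphs, while rejected nodes keep running on the built-in kernels.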
+ diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index f45fcceb2e615222ea9c14bf6da9fd0f7dc8c487..965273f0f04d33b52903c0551fff3533c31d3bd8 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -95,11 +95,7 @@ Here is a list of TensorFlow operations that are usually removed from the graph: * [tf.divide](https://www.tensorflow.org/api_docs/python/tf/divide) * [tf.fake_quant_with_min_max_args](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_args) * [tf.fake_quant_with_min_max_vars](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_vars) -* [tf.greater](https://www.tensorflow.org/api_docs/python/tf/greater) -* [tf.greater_equal](https://www.tensorflow.org/api_docs/python/tf/greater_equal) * [tf.identity](https://www.tensorflow.org/api_docs/python/tf/identity) -* [tf.less](https://www.tensorflow.org/api_docs/python/tf/less) -* [tf.less_equal](https://www.tensorflow.org/api_docs/python/tf/less_equal) * [tf.maximum](https://www.tensorflow.org/api_docs/python/tf/maximum) * [tf.minimum](https://www.tensorflow.org/api_docs/python/tf/minimum) * [tf.multiply](https://www.tensorflow.org/api_docs/python/tf/multiply) @@ -132,9 +128,7 @@ TensorFlow operation not listed above are likely unsupported. 
Notably, the following common ops are not supported at the moment: * [tf.depth_to_space](https://www.tensorflow.org/api_docs/python/tf/depth_to_space) -* [tf.gather](https://www.tensorflow.org/api_docs/python/tf/gather) * [tf.image.resize_bilinear](https://www.tensorflow.org/api_docs/python/tf/image/resize_bilinear) -* [tf.slice](https://www.tensorflow.org/api_docs/python/tf/slice) * [tf.tanh](https://www.tensorflow.org/api_docs/python/tf/tanh) ## TensorFlow Lite Operations @@ -222,6 +216,23 @@ Options { } ``` +**CONV_2D_TRANSPOSE** + +``` +Inputs { + 0: output_shape + 1: filter + 2: 4D tensor +} +Outputs { + 0: the transpose (gradient) of conv2d +} +Options { + padding: SAME|VALID + stride_w,stride_h: stride of the filter window +} +``` + **DEPTHWISE_CONV_2D** ``` @@ -242,6 +253,19 @@ Options { } ``` +**EQUAL** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is + equal to the corresponding element of the second tensor. +} +``` + **EXP** ``` @@ -281,6 +305,19 @@ Options { } ``` +**GATHER** + +``` +Inputs { + 0: params tensor + 1: indices tensor + 2: axis tensor (optional) +} +Outputs { + 0: a tensor with same type as the params tensor. +} +``` + **GREATER** ``` @@ -392,6 +429,17 @@ Outputs { } ``` +**LOG** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: a tensor equivalent to log(input) +} +``` + **LOG_SOFTMAX** ``` @@ -475,6 +523,19 @@ Options { } ``` +**NOT_EQUAL** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is not + equal to the corresponding element of the second tensor. +} +``` + **RELU** ``` @@ -523,6 +584,19 @@ Options { } ``` +**SLICE** + +``` +Inputs { + 0: tensor + 1: 1D tensor + 2: 1D tensor +} +Outputs { + 0: slice of the input tensor of the given size from the given begin index. 
+} +``` + **SOFTMAX** ``` @@ -566,6 +640,21 @@ Outputs { } ``` +**SPARSE_TO_DENSE** + +``` +Inputs { + 0: 0D or 1D or 2D tensor + 1: 1D tensor + 2: 0D or 1D tensor + 3: 0D tensor + 4: a boolean value +} +Outputs { + 0: Dense Tensor of shape output_shape. Has the same type as sparse_values. +} +``` + **SPLIT** ``` @@ -608,7 +697,7 @@ Outputs { 0: slice of the input tensor of the given size } Options { - begin_mask: mask for begin indicies + begin_mask: mask for begin indices end_mask: mask for end indices shrink_axis_mask: mask that indicates which dimensions to remove } @@ -623,7 +712,7 @@ Inputs { } Outputs { 0: k largest element along each last dimensional slice - 1: indicies of values within the last dimension of the input ensor + 1: indices of values within the last dimension of the input tensor } ``` diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 0450e86ae7f84e4aa6c70235eb825ca3b4f7aebc..7315d8360680ca0d3c405dc80b593762275815ee 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -249,13 +249,20 @@ class Interpreter { return nullptr; } - // Return a pointer into the data of a given input tensor. The given index - // must be between 0 and inputs().size(). + // Return a mutable pointer into the data of a given input tensor. The given + // index must be between 0 and inputs().size(). template <class T> T* typed_input_tensor(int index) { return typed_tensor<T>(inputs_[index]); } + // Return an immutable pointer into the data of a given input tensor. The + // given index must be between 0 and inputs().size(). + template <class T> + const T* typed_input_tensor(int index) const { + return typed_tensor<T>(inputs_[index]); + } + // Return a mutable pointer into the data of a given output tensor. The given // index must be between 0 and outputs().size(). 
template diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD index 1e579226037fa360e4d5dad25077b8966e1126bc..593af81a18a1e20a41dcc8d9bb3a1d815876e294 100644 --- a/tensorflow/contrib/lite/java/BUILD +++ b/tensorflow/contrib/lite/java/BUILD @@ -1,7 +1,9 @@ # Description: # TensorFlow Lite Java API. -package(default_visibility = ["//visibility:private"]) +package(default_visibility = [ + "//tensorflow/contrib/lite/java/ovic:__pkg__", +]) licenses(["notice"]) # Apache 2.0 @@ -46,38 +48,6 @@ android_library( ], ) -android_library( - name = "ovicbenchmarkerlib", - srcs = [ - "ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java", - "ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", - ], - manifest = "AndroidManifest.xml", - visibility = ["//visibility:public"], - deps = [ - ":tensorflowlite", - "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", - "@org_checkerframework_qual", - ], -) - -java_library( - name = "ovicbenchmarkerlib_java", - srcs = [ - "ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java", - "ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", - ], - javacopts = JAVACOPTS, - visibility = ["//visibility:public"], - deps = [ - ":libtensorflowlite_jni.so", - ":tensorflowlite_java", - "//tensorflow/contrib/lite/java/src/main/native", - "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", - "@org_checkerframework_qual", - ], -) - java_library( name = "tensorflowlitelib", srcs = glob( @@ -180,24 +150,6 @@ java_test( ], ) -java_test( - name = "OvicClassifierTest", - size = "medium", - srcs = ["ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java"], - data = [ - "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", - "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", - ], - javacopts = JAVACOPTS, - test_class = "org.tensorflow.ovic.OvicClassifierTest", - visibility = 
["//visibility:public"], - deps = [ - ":ovicbenchmarkerlib_java", - "@com_google_truth", - "@junit", - ], -) - filegroup( name = "libtensorflowlite_jni", srcs = select({ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml index ba63dce5d9a7192a2c3c4c5561333d39a3ecc024..95b6b7016f2818127a89d2e9212aa231a5ec24b9 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml @@ -31,6 +31,7 @@ android:theme="@style/MaterialTheme"> diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml index 72a229ecdb19f5309994e994d82e0b5b5ed617a2..ddb099a950c2f83d7b2867f8f35d96885229536d 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml @@ -28,7 +28,7 @@ + - + android:id="@+id/bottom_info_view" + android:layout_marginBottom="10dp" + android:layout_height="50dp"> + + + android:layout_marginLeft="10dp" + android:background="#0000000f" + android:textColor="@android:color/white" /> + + - - diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml index 72a229ecdb19f5309994e994d82e0b5b5ed617a2..e567009a424ed77384bee193c47d4f4d253f5767 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml @@ -28,7 +28,7 @@ + - + android:id="@+id/bottom_info_view" + android:layout_marginBottom="10dp" + android:layout_height="50dp"> + + + android:layout_marginLeft="10dp" + 
android:background="#0000000f" + android:textColor="@android:color/white" /> - - + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml index 0a71dbd0e8010f5e3a176de1f7e8321331289f7c..7af8f3a98c6319da7723928ce61802ed4c5497ec 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml @@ -16,7 +16,7 @@ --> - TfLiteCameraDemo + TfLite Camera Demo + Threads: diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml index 3f3bdfb49480e779c108cd15da854ae82a118d52..1752b3b5f97e288d8b59106dfece1d84fe21d0ba 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml @@ -14,5 +14,10 @@ limitations under the License. --> - + diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..362d93636f72205ddcda6d97fa9fae376ff211f1 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/BUILD @@ -0,0 +1,68 @@ +# Description: +# OVIC Benchmarker Java API. 
+ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") + +java_test( + name = "OvicClassifierTest", + size = "medium", + srcs = ["src/test/java/org/tensorflow/ovic/OvicClassifierTest.java"], + data = [ + "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", + "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", + ], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.ovic.OvicClassifierTest", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib_java", + "@com_google_truth", + "@junit", + ], +) + +java_binary( + name = "ovic_validator", + srcs = ["src/main/java/org/tensorflow/ovic/OvicValidator.java"], + data = [ + "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", + ], + main_class = "org.tensorflow.ovic.OvicValidator", + deps = [ + "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib_java", + ], +) + +android_library( + name = "ovicbenchmarkerlib", + srcs = [ + "src/main/java/org/tensorflow/ovic/OvicClassifier.java", + "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", + ], + manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml", + deps = [ + "//tensorflow/contrib/lite/java:tensorflowlite", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "@org_checkerframework_qual", + ], +) + +java_library( + name = "ovicbenchmarkerlib_java", + srcs = [ + "src/main/java/org/tensorflow/ovic/OvicClassifier.java", + "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", + ], + javacopts = JAVACOPTS, + deps = [ + "//tensorflow/contrib/lite/java:libtensorflowlite_jni.so", + "//tensorflow/contrib/lite/java:tensorflowlite_java", + "//tensorflow/contrib/lite/java/src/main/native", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "@org_checkerframework_qual", + ], +) diff --git 
a/tensorflow/contrib/lite/java/ovic/README.md b/tensorflow/contrib/lite/java/ovic/README.md index 77799b35691813868fb65a2c8b068f41751717db..26349347faebac135ae555e0c5d8219046ab1c29 100644 --- a/tensorflow/contrib/lite/java/ovic/README.md +++ b/tensorflow/contrib/lite/java/ovic/README.md @@ -2,7 +2,7 @@ This folder contains building code for track one of the [Low Power ImageNet Recognition Challenge workshop at CVPR 2018.](https://rebootingcomputing.ieee.org/home/sitemap/14-lpirc/80-low-power-image-recognition-challenge-lpirc-2018) -## Pre-requesits +## Pre-requisite Follow the steps [here](https://www.tensorflow.org/mobile/tflite/demo_android) to install Tensorflow, Bazel, and the Android NDK and SDK. @@ -37,19 +37,37 @@ unzip -j /tmp/ovic.zip -d tensorflow/contrib/lite/java/ovic/src/testdata/ You can run test with Bazel as below. This helps to ensure that the installation is correct. ```sh -bazel test --cxxopt=--std=c++11 //tensorflow/contrib/lite/java:OvicClassifierTest --cxxopt=-Wno-all --test_output=all +bazel test --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/ovic:OvicClassifierTest --cxxopt=-Wno-all --test_output=all ``` ### Test your submissions -Once you have a submission that follows the instructions from the [competition site](https://rebootingcomputing.ieee.org/home/sitemap/14-lpirc/80-low-power-image-recognition-challenge-lpirc-2018), you can verify it as below. +Once you have a submission that follows the instructions from the [competition site](https://rebootingcomputing.ieee.org/home/sitemap/14-lpirc/80-low-power-image-recognition-challenge-lpirc-2018), you can verify it in two ways: -* Move your submission to the testdata folder: +#### Validate using randomly generated images + +You can call the validator binary below to verify that your model fits the format requirements. This often helps you to catch size mismatches (e.g. output should be [1, 1001] instead of [1,1,1,1001]). 
Let's say the submission file is located at `/path/to/my_model.lite`, then call: + +```sh +bazel build --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/ovic:ovic_validator --cxxopt=-Wno-all +bazel-bin/tensorflow/contrib/lite/java/ovic/ovic_validator /path/to/my_model.lite +``` + +Successful validation should print the following message to the terminal: + +``` +Successfully validated /path/to/my_model.lite. + +``` + +#### Test that the model produces sensible outcomes -Let say the submission file is located at `/tmp/my_model.lite`, then +You can go a step further to verify that the model produces results as expected. This helps you catch bugs during TOCO conversion (e.g. using the wrong mean and std values). + +* Move your submission to the testdata folder: ```sh -cp /tmp/my_model.lite tensorflow/contrib/lite/java/ovic/src/testdata/ +cp /path/to/my_model.lite tensorflow/contrib/lite/java/ovic/src/testdata/ ``` * Resize the test image to the resolutions that are expected by your submission: @@ -136,3 +154,5 @@ Note: the benchmarking results can be quite different depending on the backgroun | quantized_model.lite | 85 | 74 | | low_res_model.lite | 4.2 | 4.0 | +Since Pixel 2 has excellent support for 8-bit quantized models, we strongly recommend that you check out the [quantization training tutorial](https://www.tensorflow.org/performance/quantization). 
+ diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD index 47101ff574a797a81c5d993b0863c024885f03a0..83974f4b337baedebaf9c9ffc0a03501418a3e36 100644 --- a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD +++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD @@ -21,8 +21,8 @@ android_binary( resource_files = glob(["res/**"]), tags = ["manual"], deps = [ - "//tensorflow/contrib/lite/java:ovicbenchmarkerlib", "//tensorflow/contrib/lite/java:tensorflowlite", + "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib", "@androidsdk//com.android.support:support-v13-25.2.0", "@androidsdk//com.android.support:support-v4-25.2.0", ], diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java new file mode 100644 index 0000000000000000000000000000000000000000..a504ec74a9d0a124f877a6377cae155f204849a5 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java @@ -0,0 +1,94 @@ +/*Copyright 2018 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +package org.tensorflow.ovic; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.PrintStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.util.Random; + +/** Validate a submission model. */ +public class OvicValidator { + private static void printUsage(PrintStream s) { + s.println("Java program that validates a submission model."); + s.println(); + s.println("Usage: ovic_validator "); + s.println(); + s.println("Where:"); + s.println(" is the model in TfLite format;"); + } + + public static void main(String[] args) { + if (args.length != 1) { + printUsage(System.err); + System.exit(1); + } + final String labelPath = + "tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt"; + + final String modelFile = args[0]; + try { + File labelsfile = new File(labelPath); + InputStream labelsInputStream = new FileInputStream(labelsfile); + MappedByteBuffer model = loadModelFile(modelFile); + OvicClassifier classifier = new OvicClassifier(labelsInputStream, model); + ByteBuffer imgData = createByteBufferForClassifier(classifier); + OvicSingleImageResult testResult = classifier.classifyByteBuffer(imgData); + if (testResult.topKClasses.isEmpty()) { + throw new RuntimeException("Failed to return top K predictions."); + } + System.out.printf("Successfully validated %s.%n", modelFile); + } catch (Exception e) { + System.out.println(e.getMessage()); + System.out.printf("Failed to validate %s.%n", modelFile); + } + } + + private static ByteBuffer createByteBufferForClassifier(OvicClassifier classifier) { + if (classifier == null) { + throw new RuntimeException("Cannot create image buffer with the classifier."); + } + int[] inputDims = classifier.getInputDims(); + int imgHeight = inputDims[1]; + int imgWidth = inputDims[2]; + 
ByteBuffer imgData = ByteBuffer.allocateDirect(imgHeight * imgWidth * 3); + imgData.order(ByteOrder.nativeOrder()); + Random rand = new Random(); + for (int y = 0; y < imgHeight; y++) { + for (int x = 0; x < imgWidth; x++) { + int val = rand.nextInt(); + imgData.put((byte) ((val >> 16) & 0xFF)); + imgData.put((byte) ((val >> 8) & 0xFF)); + imgData.put((byte) (val & 0xFF)); + } + } + return imgData; + } + + private static MappedByteBuffer loadModelFile(String modelFilePath) throws IOException { + File modelfile = new File(modelFilePath); + FileInputStream inputStream = new FileInputStream(modelfile); + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = 0L; + long declaredLength = fileChannel.size(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java index e84ee7112983ec584308b7cbcd919f119eccbcc9..fd1f0ffa68eeca7b5866b146ecaa1f9216ef377d 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -16,6 +16,7 @@ limitations under the License. package org.tensorflow.lite; import java.io.File; +import java.nio.ByteBuffer; import java.nio.MappedByteBuffer; import java.util.HashMap; import java.util.Map; @@ -80,6 +81,29 @@ public final class Interpreter implements AutoCloseable { wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), numThreads); } + /** + * Initializes a {@code Interpreter} with a {@code ByteBuffer} of a model file. + * + *

The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The + * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a + * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model. + */ + public Interpreter(@NonNull ByteBuffer byteBuffer) { + wrapper = new NativeInterpreterWrapper(byteBuffer); + } + + /** + * Initializes a {@code Interpreter} with a {@code ByteBuffer} of a model file and specifies the + * number of threads used for inference. + * + *

The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The + * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a + * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model. + */ + public Interpreter(@NonNull ByteBuffer byteBuffer, int numThreads) { + wrapper = new NativeInterpreterWrapper(byteBuffer, numThreads); + } + /** * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file. * @@ -215,11 +239,11 @@ public final class Interpreter implements AutoCloseable { } } - public void setNumThreads(int num_threads) { + public void setNumThreads(int numThreads) { if (wrapper == null) { throw new IllegalStateException("The interpreter has already been closed."); } - wrapper.setNumThreads(num_threads); + wrapper.setNumThreads(numThreads); } /** Release resources associated with the {@code Interpreter}. */ @@ -229,5 +253,14 @@ public final class Interpreter implements AutoCloseable { wrapper = null; } + @Override + protected void finalize() throws Throwable { + try { + close(); + } finally { + super.finalize(); + } + } + NativeInterpreterWrapper wrapper; } diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index a43251cad13a4ed0b35367e796948b4b9a9faa67..2ae6c516b03ef4292667bbd944c73d2eeaf82db3 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -43,21 +43,31 @@ final class NativeInterpreterWrapper implements AutoCloseable { } /** - * Initializes a {@code NativeInterpreterWrapper} with a {@code MappedByteBuffer}. The - * MappedByteBuffer should not be modified after the construction of a {@code - * NativeInterpreterWrapper}. 
+ * Initializes a {@code NativeInterpreterWrapper} with a {@code ByteBuffer}. The ByteBuffer should + * not be modified after the construction of a {@code NativeInterpreterWrapper}. The {@code + * ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a direct + * {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model. */ - NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer) { - this(mappedByteBuffer, /* numThreads= */ -1); + NativeInterpreterWrapper(ByteBuffer byteBuffer) { + this(byteBuffer, /* numThreads= */ -1); } /** - * Initializes a {@code NativeInterpreterWrapper} with a {@code MappedByteBuffer} and specifies - * the number of inference threads. The MappedByteBuffer should not be modified after the - * construction of a {@code NativeInterpreterWrapper}. + * Initializes a {@code NativeInterpreterWrapper} with a {@code ByteBuffer} and specifies the + * number of inference threads. The ByteBuffer should not be modified after the construction of a + * {@code NativeInterpreterWrapper}. The {@code ByteBuffer} can be either a {@code + * MappedByteBuffer} that memory-maps a model file, or a direct {@code ByteBuffer} of + * nativeOrder() that contains the bytes content of a model. 
*/ - NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer, int numThreads) { - modelByteBuffer = mappedByteBuffer; + NativeInterpreterWrapper(ByteBuffer buffer, int numThreads) { + if (buffer == null + || (!(buffer instanceof MappedByteBuffer) + && (!buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()))) { + throw new IllegalArgumentException( + "Model ByteBuffer should be either a MappedByteBuffer of the model file, or a direct " + + "ByteBuffer using ByteOrder.nativeOrder() which contains bytes of model content."); + } + modelByteBuffer = buffer; errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle); interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads); @@ -90,9 +100,10 @@ final class NativeInterpreterWrapper implements AutoCloseable { dataTypes[i] = dataType.getNumber(); if (dataType == DataType.BYTEBUFFER) { ByteBuffer buffer = (ByteBuffer) inputs[i]; - if (buffer.order() != ByteOrder.nativeOrder()) { + if (buffer == null || !buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()) { throw new IllegalArgumentException( - "Input error: ByteBuffer shoud use ByteOrder.nativeOrder()."); + "Input error: ByteBuffer should be a direct ByteBuffer that uses " + + "ByteOrder.nativeOrder()."); } numsOfBytes[i] = buffer.limit(); sizes[i] = getInputDims(interpreterHandle, i, numsOfBytes[i]); @@ -153,8 +164,8 @@ final class NativeInterpreterWrapper implements AutoCloseable { useNNAPI(interpreterHandle, useNNAPI); } - void setNumThreads(int num_threads) { - numThreads(interpreterHandle, num_threads); + void setNumThreads(int numThreads) { + numThreads(interpreterHandle, numThreads); } /** Gets index of an input given its name. 
*/ @@ -314,7 +325,7 @@ final class NativeInterpreterWrapper implements AutoCloseable { private long inferenceDurationNanoseconds = -1; - private MappedByteBuffer modelByteBuffer; + private ByteBuffer modelByteBuffer; private Map inputsIndexes; @@ -328,13 +339,13 @@ final class NativeInterpreterWrapper implements AutoCloseable { private static native void useNNAPI(long interpreterHandle, boolean state); - private static native void numThreads(long interpreterHandle, int num_threads); + private static native void numThreads(long interpreterHandle, int numThreads); private static native long createErrorReporter(int size); private static native long createModel(String modelPathOrBuffer, long errorHandle); - private static native long createModelWithBuffer(MappedByteBuffer modelBuffer, long errorHandle); + private static native long createModelWithBuffer(ByteBuffer modelBuffer, long errorHandle); private static native long createInterpreter(long modelHandle, long errorHandle, int numThreads); diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index 45f510da1d940a288e2794cb3e08f66451956b64..1fb6997fb9ba180e9a3f3a89a6d177086440c0d7 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -387,7 +387,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( jlong capacity = env->GetDirectBufferCapacity(model_buffer); if (!VerifyModel(buf, capacity)) { throwException(env, kIllegalArgumentException, - "MappedByteBuffer is not a valid flatbuffer model"); + "ByteBuffer is not a valid flatbuffer model"); return 0; } @@ -395,8 +395,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( buf, static_cast(capacity), error_reporter); if (!model) { throwException(env, kIllegalArgumentException, - 
"MappedByteBuffer does not encode a valid " - "TensorFlowLite model: %s", + "ByteBuffer does not encode a valid model: %s", error_reporter->CachedErrorMessage()); return 0; } @@ -426,7 +425,8 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( status = interpreter->AllocateTensors(); if (status != kTfLiteOk) { throwException(env, kNullPointerException, - "Internal error: Cannot allocate memory for the interpreter", + "Internal error: Cannot allocate memory for the interpreter:" + " %s", error_reporter->CachedErrorMessage()); return 0; } diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc index 005dca0253d2c30d56a15adf6e2b371d43f50945..9e9387da86ebde7d711a7ce967461e370c95bc3e 100644 --- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc @@ -43,31 +43,27 @@ size_t writeOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type, } switch (type) { case kTfLiteFloat32: { - jfloatArray a = static_cast(array); - jfloat* values = env->GetFloatArrayElements(a, nullptr); - memcpy(dst, values, to_copy); - env->ReleaseFloatArrayElements(a, values, JNI_ABORT); + jfloatArray float_array = static_cast(array); + jfloat* float_dst = static_cast(dst); + env->GetFloatArrayRegion(float_array, 0, num_elements, float_dst); return to_copy; } case kTfLiteInt32: { - jintArray a = static_cast(array); - jint* values = env->GetIntArrayElements(a, nullptr); - memcpy(dst, values, to_copy); - env->ReleaseIntArrayElements(a, values, JNI_ABORT); + jintArray int_array = static_cast(array); + jint* int_dst = static_cast(dst); + env->GetIntArrayRegion(int_array, 0, num_elements, int_dst); return to_copy; } case kTfLiteInt64: { - jlongArray a = static_cast(array); - jlong* values = env->GetLongArrayElements(a, nullptr); - memcpy(dst, values, to_copy); - env->ReleaseLongArrayElements(a, values, JNI_ABORT); + jlongArray 
long_array = static_cast(array); + jlong* long_dst = static_cast(dst); + env->GetLongArrayRegion(long_array, 0, num_elements, long_dst); return to_copy; } case kTfLiteUInt8: { - jbyteArray a = static_cast(array); - jbyte* values = env->GetByteArrayElements(a, nullptr); - memcpy(dst, values, to_copy); - env->ReleaseByteArrayElements(a, values, JNI_ABORT); + jbyteArray byte_array = static_cast(array); + jbyte* byte_dst = static_cast(dst); + env->GetByteArrayRegion(byte_array, 0, num_elements, byte_dst); return to_copy; } default: { diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index 210d9437241f117ab281b627a4352fce7d340bcb..82007a6ab5be3492495125b1c20ed155907ae5a0 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -19,6 +19,8 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; import java.io.File; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.nio.file.Files; @@ -69,6 +71,49 @@ public final class InterpreterTest { fileChannel.close(); } + @Test + public void testRunWithDirectByteBufferModel() throws Exception { + Path path = MODEL_FILE.toPath(); + FileChannel fileChannel = + (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ)); + ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) fileChannel.size()); + byteBuffer.order(ByteOrder.nativeOrder()); + fileChannel.read(byteBuffer); + Interpreter interpreter = new Interpreter(byteBuffer); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD 
= {threeD, threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + interpreter.run(fourD, parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + interpreter.close(); + fileChannel.close(); + } + + @Test + public void testRunWithInvalidByteBufferModel() throws Exception { + Path path = MODEL_FILE.toPath(); + FileChannel fileChannel = + (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ)); + ByteBuffer byteBuffer = ByteBuffer.allocate((int) fileChannel.size()); + byteBuffer.order(ByteOrder.nativeOrder()); + fileChannel.read(byteBuffer); + try { + Interpreter interpreter = new Interpreter(byteBuffer); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Model ByteBuffer should be either a MappedByteBuffer" + + " of the model file, or a direct ByteBuffer using ByteOrder.nativeOrder()"); + } + fileChannel.close(); + } + @Test public void testRun() { Interpreter interpreter = new Interpreter(MODEL_FILE); diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index 79e3c9f2664594c51b2a0cdf6b7a24ee7baa5bec..cf5d0b4ce9cb3c516c185f31fea12db70a2c3bdb 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -143,9 +143,11 @@ cc_library( "depthwise_conv.cc", "dequantize.cc", "div.cc", + "elementwise.cc", "embedding_lookup.cc", "embedding_lookup_sparse.cc", "exp.cc", + "expand_dims.cc", "floor.cc", "fully_connected.cc", "gather.cc", @@ -166,15 +168,19 @@ cc_library( "resize_bilinear.cc", "select.cc", "skip_gram.cc", + "slice.cc", "space_to_batch_nd.cc", "space_to_depth.cc", + "sparse_to_dense.cc", "split.cc", "squeeze.cc", "strided_slice.cc", "sub.cc", "svdf.cc", + "tile.cc", "topk_v2.cc", "transpose.cc", + "transpose_conv.cc", "unidirectional_sequence_lstm.cc", 
"unidirectional_sequence_rnn.cc", ], @@ -454,6 +460,19 @@ tf_cc_test( ], ) +tf_cc_test( + name = "elementwise_test", + size = "small", + srcs = ["elementwise_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "unidirectional_sequence_lstm_test", size = "small", @@ -841,6 +860,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "tile_test", + size = "small", + srcs = ["tile_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "comparisons_test", size = "small", @@ -888,6 +921,64 @@ tf_cc_test( ], ) +tf_cc_test( + name = "slice_test", + size = "small", + srcs = [ + "slice_test.cc", + ], + tags = [ + "tflite_not_portable_ios", + ], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "transpose_conv_test", + size = "small", + srcs = ["transpose_conv_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "expand_dims_test", + size = "small", + srcs = ["expand_dims_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "sparse_to_dense_test", + size = "small", + srcs = ["sparse_to_dense_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = 
[ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc index 39a54c93962b33f3a787b3387d9a133119d0e80a..add36b46c0b8a4deab1e842d50194c8b99a3a20c 100644 --- a/tensorflow/contrib/lite/kernels/activations.cc +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -55,7 +55,7 @@ void Free(TfLiteContext* context, void* buffer) { TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); TF_LITE_ENSURE_EQ(context, input->type, output->type); @@ -68,7 +68,7 @@ TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); TF_LITE_ENSURE_EQ(context, input->type, output->type); @@ -95,7 +95,7 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); TF_LITE_ENSURE_EQ(context, input->type, output->type); @@ -126,7 +126,7 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, 
NumOutputs(node), 1); - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); TF_LITE_ENSURE_EQ(context, input->type, output->type); @@ -153,9 +153,9 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); - TfLiteTensor* alpha = GetInput(context, node, 1); + const TfLiteTensor* alpha = GetInput(context, node, 1); output->type = input->type; @@ -179,7 +179,7 @@ TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) { } TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); switch (input->type) { case kTfLiteFloat32: { @@ -191,13 +191,14 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } break; default: - context->ReportError(context, "Only float32 supported currently."); + context->ReportError(context, "Only float32 supported currently, got %d.", + input->type); return kTfLiteError; } } TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); switch (input->type) { case kTfLiteFloat32: { @@ -211,13 +212,14 @@ TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } break; default: - context->ReportError(context, "Only float32 supported currently."); + context->ReportError(context, "Only float32 
supported currently, got %d.", + input->type); return kTfLiteError; } } TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); switch (input->type) { case kTfLiteFloat32: { @@ -229,14 +231,15 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } break; default: - context->ReportError(context, "Only float32 supported currently."); + context->ReportError(context, "Only float32 supported currently, got %d.", + input->type); return kTfLiteError; } } TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { OpData* data = reinterpret_cast(node->user_data); - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); switch (input->type) { case kTfLiteFloat32: { @@ -256,7 +259,8 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } break; default: - context->ReportError(context, "Only float32 supported currently."); + context->ReportError(context, "Only float32 supported currently, got %d.", + input->type); return kTfLiteError; } } @@ -265,7 +269,7 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) { OpData* data = reinterpret_cast(node->user_data); - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); switch (input->type) { case kTfLiteFloat32: { @@ -285,14 +289,15 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) { break; } default: - context->ReportError(context, "Only float32 supported currently."); + context->ReportError(context, "Only float32 supported currently, got %d.", + input->type); return kTfLiteError; 
} return kTfLiteOk; } // Takes a 2D tensor and performs softmax along the second dimension. -void Softmax2DFloat(TfLiteTensor* input, TfLiteTensor* output, +void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output, TfLiteSoftmaxParams* params) { const int batch_size = input->dims->data[0]; const int input_size = input->dims->data[1]; @@ -327,7 +332,7 @@ void Softmax2DFloat(TfLiteTensor* input, TfLiteTensor* output, } } -void Softmax2DQuantized(TfLiteTensor* input, TfLiteTensor* output, +void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output, TfLiteSoftmaxParams* params, OpData* data) { // TODO(ahentz): this is arguably a dirty trick. Since the implementation // always traverses the last dimension of a 4D tensor, we will pretend our 2D @@ -343,14 +348,14 @@ void Softmax2DQuantized(TfLiteTensor* input, TfLiteTensor* output, } // Takes a 4D tensor and performs softmax along the fourth dimension. -void Softmax4DFloat(TfLiteTensor* input, TfLiteTensor* output, +void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output, TfLiteSoftmaxParams* params) { optimized_ops::Softmax(GetTensorData(input), GetTensorDims(input), params->beta, GetTensorData(output), GetTensorDims(output)); } -void Softmax4DQuantized(TfLiteTensor* input, TfLiteTensor* output, +void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output, TfLiteSoftmaxParams* params, OpData* data) { optimized_ops::Softmax(GetTensorData(input), GetTensorDims(input), data->input_multiplier, data->input_left_shift, @@ -362,7 +367,7 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); OpData* data = reinterpret_cast(node->user_data); - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); // TODO(ahentz): consider an implementation that works for many (all?)
@@ -377,8 +382,9 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { Softmax4DFloat(input, output, params); return kTfLiteOk; } - context->ReportError(context, - "Only 2D and 4D tensors supported currently."); + context->ReportError( + context, "Only 2D and 4D tensors supported currently, got %dD.", + NumDimensions(input)); return kTfLiteError; } case kTfLiteUInt8: { @@ -390,19 +396,21 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { Softmax4DQuantized(input, output, params, data); return kTfLiteOk; } - context->ReportError(context, - "Only 2D and 4D tensors supported currently."); + context->ReportError( + context, "Only 2D and 4D tensors supported currently, got %dD.", + NumDimensions(input)); return kTfLiteError; } default: - context->ReportError(context, - "Only float32 and uint8_t supported currently."); + context->ReportError( + context, "Only float32 and uint8_t supported currently, got %d.", + input->type); return kTfLiteError; } } TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); switch (input->type) { case kTfLiteFloat32: @@ -411,18 +419,20 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) { GetTensorData(output), GetTensorDims(output)); return kTfLiteOk; default: - context->ReportError(context, "Only float32 supported currently."); + context->ReportError(context, "Only float32 supported currently, got %d.", + input->type); return kTfLiteError; } } TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* input = GetInput(context, node, 0); - TfLiteTensor* alpha = GetInput(context, node, 1); - TfLiteTensor* output = GetOutput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* alpha = GetInput(context, node, 1); + const TfLiteTensor*
output = GetOutput(context, node, 0); if (input->type != kTfLiteFloat32) { - context->ReportError(context, "Only float32 supported currently."); + context->ReportError(context, "Only float32 supported currently, got %d.", + input->type); return kTfLiteError; } TF_LITE_ENSURE_EQ(context, input->dims->size, 4); diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc index e0aa070e2d02cecb9c6ff500ab32b8ad2159db6e..7ca1e35489cba3b5d2567bc04e532fedf8a527a7 100644 --- a/tensorflow/contrib/lite/kernels/add.cc +++ b/tensorflow/contrib/lite/kernels/add.cc @@ -57,8 +57,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); - TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TF_LITE_ENSURE_EQ(context, input1->type, input2->type); @@ -80,7 +80,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { template void EvalAddFloat(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params, const OpData* data, - TfLiteTensor* input1, TfLiteTensor* input2, + const TfLiteTensor* input1, const TfLiteTensor* input2, TfLiteTensor* output) { float output_activation_min, output_activation_max; CalculateActivationRangeFloat(params->activation, &output_activation_min, @@ -109,7 +109,7 @@ void EvalAddFloat(TfLiteContext* context, TfLiteNode* node, template void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params, const OpData* data, - TfLiteTensor* input1, TfLiteTensor* input2, + const TfLiteTensor* input1, const TfLiteTensor* input2, TfLiteTensor* output) { auto input1_offset = -input1->params.zero_point; auto 
input2_offset = -input2->params.zero_point; @@ -164,8 +164,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); OpData* data = reinterpret_cast(node->user_data); - TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); - TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); if (output->type == kTfLiteFloat32) { diff --git a/tensorflow/contrib/lite/kernels/arg_max.cc b/tensorflow/contrib/lite/kernels/arg_max.cc index a2c5e4ceadbc905d22eb02b450c88745a351f58f..26f57e88962116f446e72fbc164d2747e8b633b4 100644 --- a/tensorflow/contrib/lite/kernels/arg_max.cc +++ b/tensorflow/contrib/lite/kernels/arg_max.cc @@ -33,8 +33,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input = GetInput(context, node, kInputTensor); - TfLiteTensor* axis = GetInput(context, node, kAxis); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* axis = GetInput(context, node, kAxis); // Make sure the axis is only 1 dimension. 
TF_LITE_ENSURE_EQ(context, NumElements(axis), 1); @@ -52,7 +52,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output->type = kTfLiteInt64; break; default: - context->ReportError(context, "Unknown index output data type"); + context->ReportError(context, "Unknown index output data type: %d", + params->output_type); return kTfLiteError; } @@ -64,7 +65,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { break; default: - context->ReportError(context, "Only float32 and int types are supported"); + context->ReportError( + context, + "Unknown input type: %d, only float32 and int types are supported", + input->type); return kTfLiteError; } @@ -79,12 +83,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // The current impl actually ignores the axis argument. // Only determine the index of the maximum value in the last dimension. TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* input = GetInput(context, node, kInputTensor); - TfLiteTensor* axis = GetInput(context, node, kAxis); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* axis = GetInput(context, node, kAxis); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); #define TF_LITE_ARG_MAX(data_type, axis_type, output_type) \ - TF_LITE_ENSURE_EQ(context, GetTensorData(axis)[0], 3); \ optimized_ops::ArgMax(GetTensorData(axis), \ GetTensorData(input), GetTensorDims(input), \ GetTensorData(output), \ diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc index 602f3888c10b3790dc0328c817bdd83276544b25..91d8dd3fa71b4f2ac70c64c4923c5240b61a2b25 100644 --- a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc +++ b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc @@ -72,7 +72,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context,
NumOutputs(node), 1); - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2); @@ -102,7 +102,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->user_data); - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TF_LITE_ENSURE(context, params->spectrogram->Initialize(params->window_size, diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc index 2c5074eca3176c7f33a6f051b492dc41333257ed..c09b15b3d263d6cd639234590c99a50a9a48f4a7 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc @@ -12,18 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include -#include -#include -#include -#include -#include -#include +#include +#include #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { @@ -35,20 +31,29 @@ constexpr int kInputTensor = 0; constexpr int kWeightsTensor = 1; constexpr int kRecurrentWeightsTensor = 2; constexpr int kBiasTensor = 3; -constexpr int KHiddenStateTensor = 0; +constexpr int kHiddenStateTensor = 0; constexpr int kOutputTensor = 1; +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* scratch_tensor_index = new int; + context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index); + return scratch_tensor_index; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Check we have all the inputs and outputs we need. 
TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); - TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; - TfLiteTensor* input_weights = - &context->tensors[node->inputs->data[kWeightsTensor]]; - TfLiteTensor* recurrent_weights = - &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; - TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor); + const TfLiteTensor* recurrent_weights = + GetInput(context, node, kRecurrentWeightsTensor); + const TfLiteTensor* bias = GetInput(context, node, kBiasTensor); // Check all the parameters of tensor match within themselves and match the // input configuration. @@ -58,10 +63,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]); TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]); TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, input_weights->type, recurrent_weights->type); - TfLiteTensor* hidden_state = - &context->tensors[node->outputs->data[KHiddenStateTensor]]; - TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; + TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // Resize state. TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2); @@ -80,25 +86,54 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, output_size_array)); + // Allocate temporary tensors to store quantized values of input and + // hidden_state tensors. 
+ if (input->type == kTfLiteFloat32 && input_weights->type == kTfLiteUInt8) { + int* scratch_tensor_index = reinterpret_cast(node->user_data); + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(3); + node->temporaries->data[0] = *scratch_tensor_index; + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); + input_quantized->type = kTfLiteUInt8; + input_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { + TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, + input_quantized_size)); + } + node->temporaries->data[1] = *scratch_tensor_index + 1; + TfLiteTensor* hidden_state_quantized = + GetTemporary(context, node, /*index=*/1); + hidden_state_quantized->type = kTfLiteUInt8; + hidden_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(hidden_state_quantized->dims, + hidden_state->dims)) { + TfLiteIntArray* hidden_state_quantized_size = + TfLiteIntArrayCopy(hidden_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, hidden_state_quantized, + hidden_state_quantized_size)); + } + node->temporaries->data[2] = *scratch_tensor_index + 2; + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = batch_size; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } + } + return kTfLiteOk; } -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); - - TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; - 
TfLiteTensor* input_weights = - &context->tensors[node->inputs->data[kWeightsTensor]]; - TfLiteTensor* recurrent_weights = - &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; - TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; - TfLiteTensor* hidden_state = - &context->tensors[node->outputs->data[KHiddenStateTensor]]; - TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; - - // Initialize the pointer bias. - const float* bias_ptr = bias->data.f; - +TfLiteStatus EvalFloat(const TfLiteTensor* input, + const TfLiteTensor* input_weights, + const TfLiteTensor* recurrent_weights, + const TfLiteTensor* bias, const TfLiteRNNParams* params, + TfLiteTensor* hidden_state, TfLiteTensor* output) { const int batch_size = input->dims->data[0]; const int num_units = input_weights->dims->data[0]; const int input_size = input->dims->data[1]; @@ -108,9 +143,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Initialize the pointer to input and output. const float* input_ptr_batch = input->data.f; float* output_ptr_batch = output->data.f; - // Initialize input_weights and recurrent_weights. + // Initialize input_weights, recurrent_weights and bias. 
const float* input_weights_ptr = input_weights->data.f; const float* recurrent_weights_ptr = recurrent_weights->data.f; + const float* bias_ptr = bias->data.f; kernel_utils::RnnBatchStep(input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, bias_ptr, input_size, @@ -119,11 +155,85 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus EvalHybrid(const TfLiteTensor* input, + const TfLiteTensor* input_weights, + const TfLiteTensor* recurrent_weights, + const TfLiteTensor* bias, const TfLiteRNNParams* params, + TfLiteTensor* input_scratch, + TfLiteTensor* hidden_state_scratch, + TfLiteTensor* scaling_factors, + TfLiteTensor* hidden_state, TfLiteTensor* output) { + const int batch_size = input->dims->data[0]; + const int num_units = input_weights->dims->data[0]; + const int input_size = input->dims->data[1]; + + // Initialize the pointer to hidden state. + float* hidden_state_ptr_batch = hidden_state->data.f; + // Initialize the pointer to input and output. + const float* input_ptr_batch = input->data.f; + float* output_ptr_batch = output->data.f; + // Initialize input_weights, recurrent_weights and bias. + const int8_t* input_weights_ptr = + reinterpret_cast(input_weights->data.uint8); + const int8_t* recurrent_weights_ptr = + reinterpret_cast(recurrent_weights->data.uint8); + const float* bias_ptr = bias->data.f; + // Get the scale of the quantized weights. + float input_weights_scale = input_weights->params.scale; + float recurrent_weights_scale = recurrent_weights->params.scale; + // Initialize temporary storage for quantized values. 
+ int8_t* quantized_input_ptr = + reinterpret_cast(input_scratch->data.uint8); + int8_t* quantized_hidden_state_ptr = + reinterpret_cast(hidden_state_scratch->data.uint8); + float* scaling_factors_ptr = scaling_factors->data.f; + + kernel_utils::RnnBatchStep( + input_ptr_batch, input_weights_ptr, input_weights_scale, + recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, + num_units, batch_size, params->activation, quantized_input_ptr, + quantized_hidden_state_ptr, scaling_factors_ptr, hidden_state_ptr_batch, + output_ptr_batch); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor); + const TfLiteTensor* recurrent_weights = + GetInput(context, node, kRecurrentWeightsTensor); + const TfLiteTensor* bias = GetInput(context, node, kBiasTensor); + TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // We already checked that weight types are consistent, so branch on one. + switch (input_weights->type) { + case kTfLiteFloat32: + return EvalFloat(input, input_weights, recurrent_weights, bias, params, + hidden_state, output); + case kTfLiteUInt8: { + // TODO(mirkov): implement eval with quantized inputs as well. 
+ TfLiteTensor* input_quantized = GetTemporary(context, node, 0); + TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1); + TfLiteTensor* scaling_factors = GetTemporary(context, node, 2); + return EvalHybrid(input, input_weights, recurrent_weights, bias, params, + input_quantized, hidden_state_quantized, + scaling_factors, hidden_state, output); + } + default: + context->ReportError(context, "Type %d not currently supported.", + input_weights->type); + return kTfLiteError; + } + return kTfLiteOk; +} + } // namespace rnn TfLiteRegistration* Register_RNN() { - static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, - rnn::Prepare, rnn::Eval}; + static TfLiteRegistration r = {rnn::Init, rnn::Free, rnn::Prepare, rnn::Eval}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc index fa7ef525db47c93f98951604cd04da66196422d7..96465fcaf0a78527237faa7b82ddbc32ec56d114 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc @@ -14,7 +14,9 @@ limitations under the License. ==============================================================================*/ // Unit test for TFLite RNN op. 
-#include +#include +#include +#include #include #include @@ -122,13 +124,62 @@ static float rnn_golden_output[] = { 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, 0.628881, 3.58099, 1.49974, 0}; +static std::initializer_list rnn_weights = { + 0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, + 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, + 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, + -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, + -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, + -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, + -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, + 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, + 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, + 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, + -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, + 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, + -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, + -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, + 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, + 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, + 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, + -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, + 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, + 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, + -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, + 0.277308, 0.415818}; + +static std::initializer_list rnn_recurrent_weights = { + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1}; + +static std::initializer_list rnn_bias = { + 0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, -0.23566568, + -0.389184, 0.47481549, -0.4791103, 0.29931796, 0.10463274, 0.83918178, + 0.37197268, 0.61957061, 0.3956964, -0.37609905}; + class RNNOpModel : public SingleOpModel { public: - RNNOpModel(int batches, int units, int size) + RNNOpModel(int batches, int units, int size, + const TensorType& weights = TensorType_FLOAT32, + const TensorType& recurrent_weights = TensorType_FLOAT32) : batches_(batches), units_(units), input_size_(size) { input_ = AddInput(TensorType_FLOAT32); - weights_ = AddInput(TensorType_FLOAT32); - recurrent_weights_ = AddInput(TensorType_FLOAT32); + weights_ = AddInput(weights); + recurrent_weights_ = AddInput(recurrent_weights); bias_ = AddInput(TensorType_FLOAT32); hidden_state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); @@ -173,7 +224,7 @@ class RNNOpModel : public SingleOpModel { int num_units() { return units_; } int num_batches() { return batches_; } - private: + protected: int input_; int weights_; int recurrent_weights_; @@ -186,53 +237,26 @@ class RNNOpModel : public SingleOpModel { int input_size_; }; -TEST(FullyConnectedOpTest, BlackBoxTest) { +// The hybrid model has quantized weights and recurrent_weights. 
+class HybridRNNOpModel : public RNNOpModel { + public: + HybridRNNOpModel(int batches, int units, int size) + : RNNOpModel(batches, units, size, TensorType_UINT8, TensorType_UINT8) {} + + void SetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(weights_, f); + } + + void SetRecurrentWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_weights_, f); + } +}; + +TEST(RnnOpTest, BlackBoxTest) { RNNOpModel rnn(2, 16, 8); - rnn.SetWeights( - {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, - 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, - 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, - -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, - -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, - -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, - -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, - 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, - 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, - 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, - -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, - 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, - -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, - -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, - 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, - 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, - 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, - -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, - 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, - 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, - -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, - 0.277308, 0.415818}); - - rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, - -0.23566568, -0.389184, 0.47481549, 
-0.4791103, 0.29931796, - 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, - -0.37609905}); - - rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1}); + rnn.SetWeights(rnn_weights); + rnn.SetBias(rnn_bias); + rnn.SetRecurrentWeights(rnn_recurrent_weights); rnn.ResetHiddenState(); const int input_sequence_size = sizeof(rnn_input) / sizeof(float) / @@ -256,6 +280,35 @@ TEST(FullyConnectedOpTest, BlackBoxTest) { } } +TEST(HybridRnnOpTest, BlackBoxTest) { + HybridRNNOpModel rnn(2, 16, 8); + rnn.SetWeights(rnn_weights); + rnn.SetBias(rnn_bias); + rnn.SetRecurrentWeights(rnn_recurrent_weights); + + rnn.ResetHiddenState(); + const int input_sequence_size = sizeof(rnn_input) / sizeof(float) / + (rnn.input_size() * rnn.num_batches()); + + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + rnn.SetInput(0, batch_start, batch_end); + rnn.SetInput(rnn.input_size(), batch_start, batch_end); + + rnn.Invoke(); + + float* golden_start = rnn_golden_output + i * rnn.num_units(); + float* golden_end = golden_start + rnn.num_units(); + std::vector expected; + 
expected.insert(expected.end(), golden_start, golden_end); + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear( + expected, /*max_abs_error=*/0.0104))); + } +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc index 90edf4f9e3683f2987dd8299a1cd5233caa24479..c8cee88edfdbf42f422f66e4d0ca6eeb5eccbf8d 100644 --- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc @@ -40,9 +40,9 @@ struct BatchToSpaceNDContext { crops = GetInput(context, node, 2); output = GetOutput(context, node, 0); } - TfLiteTensor* input; - TfLiteTensor* block_shape; - TfLiteTensor* crops; + const TfLiteTensor* input; + const TfLiteTensor* block_shape; + const TfLiteTensor* crops; TfLiteTensor* output; }; @@ -66,12 +66,10 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops), kSpatialDimensionNum); - // TODO(ycling): Add crops as part of calculation. Remove check for a crops - // containing all zeroes. - TF_LITE_ENSURE_EQ(context, crops[0], 0); - TF_LITE_ENSURE_EQ(context, crops[1], 0); - TF_LITE_ENSURE_EQ(context, crops[2], 0); - TF_LITE_ENSURE_EQ(context, crops[3], 0); + TF_LITE_ENSURE(context, crops[0] >= 0); + TF_LITE_ENSURE(context, crops[1] >= 0); + TF_LITE_ENSURE(context, crops[2] >= 0); + TF_LITE_ENSURE(context, crops[3] >= 0); // Number of batch must be multiple of (block_shape[0] * block_shape[1]). 
TF_LITE_ENSURE_EQ(context, @@ -79,8 +77,16 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, const int output_batch_size = input_size->data[0] / (block_shape[0] * block_shape[1]); - const int output_height = input_size->data[1] * block_shape[0]; - const int output_width = input_size->data[2] * block_shape[1]; + + const int crops_top = crops[0]; + const int crops_bottom = crops[1]; + const int crops_left = crops[2]; + const int crops_right = crops[3]; + const int output_height = + input_size->data[1] * block_shape[0] - crops_top - crops_bottom; + const int output_width = + input_size->data[2] * block_shape[1] - crops_left - crops_right; + const int output_channel_size = input_size->data[3]; TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size); @@ -157,8 +163,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } break; default: - context->ReportError(context, - "Type is currently not supported by BatchToSpace."); + context->ReportError( + context, "Type %d is currently not supported by BatchToSpace.", + op_context.input->type); return kTfLiteError; } #undef TF_LITE_BATCH_TO_SPACE_ND diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc index 8485cde1b40066f2070855bca91ea78a9f80e83c..95b025c1b30cc627cf5858ec17f8ff7c57f7bd95 100644 --- a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc @@ -120,16 +120,16 @@ TEST(BatchToSpaceNDOpTest, InvalidShapeTest) { } TEST(BatchToSpaceNDOpTest, InvalidCropsConstTest) { - EXPECT_DEATH(BatchToSpaceNDOpConstModel({3, 2, 2, 1}, {2, 2}, {0, 0, 0, 1}), - "1 != 0"); + EXPECT_DEATH(BatchToSpaceNDOpConstModel({3, 2, 2, 1}, {2, 2}, {0, 0, 0, -1}), + "crops.3. 
>= 0 was not true."); } TEST(BatchToSpaceNDOpTest, InvalidCropsDynamicTest) { BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.SetBlockShape({2, 2}); - m.SetCrops({0, 0, 1, 0}); - EXPECT_DEATH(m.Invoke(), "1 != 0"); + m.SetCrops({0, 0, -1, 0}); + EXPECT_DEATH(m.Invoke(), "crops.2. >= 0 was not true."); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc index a35ba23cedec437206caa780f4965272d0afb7a8..3425288f027a6fd9eb65f730bc7d039c832ace1c 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc @@ -135,7 +135,7 @@ TfLiteStatus CheckLstmTensorDimensions( TF_LITE_ENSURE(context, params->cell_clip >= 0); TF_LITE_ENSURE(context, params->proj_clip >= 0); - TfLiteTensor* input_to_input_weights = + const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(context, node, input_to_input_weights_tensor); if (input_to_input_weights) { TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); @@ -143,19 +143,19 @@ TfLiteStatus CheckLstmTensorDimensions( TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); } - TfLiteTensor* input_to_forget_weights = + const TfLiteTensor* input_to_forget_weights = GetInput(context, node, input_to_forget_weights_tensor); TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell); TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input); - TfLiteTensor* input_to_cell_weights = + const TfLiteTensor* input_to_cell_weights = GetInput(context, node, input_to_cell_weights_tensor); TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell); TF_LITE_ENSURE_EQ(context, 
input_to_cell_weights->dims->data[1], n_input); - TfLiteTensor* recurrent_to_input_weights = + const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(context, node, recurrent_to_input_weights_tensor); if (recurrent_to_input_weights) { TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2); @@ -165,7 +165,7 @@ TfLiteStatus CheckLstmTensorDimensions( n_output); } - TfLiteTensor* recurrent_to_forget_weights = + const TfLiteTensor* recurrent_to_forget_weights = GetInput(context, node, recurrent_to_forget_weights_tensor); TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0], @@ -173,7 +173,7 @@ TfLiteStatus CheckLstmTensorDimensions( TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1], n_output); - TfLiteTensor* recurrent_to_cell_weights = + const TfLiteTensor* recurrent_to_cell_weights = GetInput(context, node, recurrent_to_cell_weights_tensor); TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell); @@ -189,21 +189,21 @@ TfLiteStatus CheckLstmTensorDimensions( (recurrent_to_input_weights == nullptr)); TF_LITE_ENSURE(context, cifg_weights_all_or_none == true); - TfLiteTensor* cell_to_input_weights = + const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(context, node, cell_to_input_weights_tensor); if (cell_to_input_weights) { TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1); TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell); } - TfLiteTensor* cell_to_forget_weights = + const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(context, node, cell_to_forget_weights_tensor); if (cell_to_forget_weights) { TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1); TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell); } - TfLiteTensor* 
cell_to_output_weights = + const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(context, node, cell_to_output_weights_tensor); if (cell_to_output_weights) { TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1); @@ -222,7 +222,7 @@ TfLiteStatus CheckLstmTensorDimensions( TF_LITE_ENSURE(context, peephole_weights_all_or_none == true); // Make sure the input gate bias is present only when not a CIFG-LSTM. - TfLiteTensor* input_gate_bias = + const TfLiteTensor* input_gate_bias = GetOptionalInputTensor(context, node, input_gate_bias_tensor); if (use_cifg) { TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); @@ -231,21 +231,22 @@ TfLiteStatus CheckLstmTensorDimensions( TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); } - TfLiteTensor* forget_gate_bias = + const TfLiteTensor* forget_gate_bias = GetInput(context, node, forget_gate_bias_tensor); TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell); - TfLiteTensor* cell_bias = GetInput(context, node, cell_gate_bias_tensor); + const TfLiteTensor* cell_bias = + GetInput(context, node, cell_gate_bias_tensor); TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell); - TfLiteTensor* output_gate_bias = + const TfLiteTensor* output_gate_bias = GetInput(context, node, output_gate_bias_tensor); TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell); - TfLiteTensor* projection_weights = + const TfLiteTensor* projection_weights = GetOptionalInputTensor(context, node, projection_weights_tensor); if (projection_weights) { TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2); @@ -253,7 +254,7 @@ TfLiteStatus CheckLstmTensorDimensions( TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell); } - TfLiteTensor* projection_bias = + const TfLiteTensor* projection_bias 
= GetOptionalInputTensor(context, node, projection_bias_tensor); if (projection_bias) { TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); @@ -312,20 +313,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Inferring batch size, number of outputs and sequence length and // number of cells from the input tensors. - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TF_LITE_ENSURE(context, input->dims->size > 1); const int max_time = input->dims->data[0]; const int n_batch = input->dims->data[1]; const int n_input = input->dims->data[2]; - TfLiteTensor* fw_input_to_output_weights = + const TfLiteTensor* fw_input_to_output_weights = GetInput(context, node, kFwInputToOutputWeightsTensor); const int n_fw_cell = fw_input_to_output_weights->dims->data[0]; TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->data[1], n_input); - TfLiteTensor* fw_recurrent_to_output_weights = + const TfLiteTensor* fw_recurrent_to_output_weights = GetInput(context, node, kFwRecurrentToOutputWeightsTensor); TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->data[0], @@ -373,7 +374,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { fw_output_state->allocation_type = kTfLiteArenaRwPersistent; fw_cell_state->allocation_type = kTfLiteArenaRwPersistent; - TfLiteTensor* fw_input_to_input_weights = + const TfLiteTensor* fw_input_to_input_weights = GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor); const bool fw_use_cifg = (fw_input_to_input_weights == nullptr); TfLiteIntArray* fw_scratch_buffer_size = TfLiteIntArrayCreate(2); @@ -388,14 +389,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_scratch_buffer, 
fw_scratch_buffer_size)); // Same for the backward cell. - TfLiteTensor* bw_input_to_output_weights = + const TfLiteTensor* bw_input_to_output_weights = GetInput(context, node, kBwInputToOutputWeightsTensor); const int n_bw_cell = bw_input_to_output_weights->dims->data[0]; TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1], n_input); - TfLiteTensor* bw_recurrent_to_output_weights = + const TfLiteTensor* bw_recurrent_to_output_weights = GetInput(context, node, kBwRecurrentToOutputWeightsTensor); TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0], @@ -441,7 +442,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { bw_output_state->allocation_type = kTfLiteArenaRwPersistent; bw_cell_state->allocation_type = kTfLiteArenaRwPersistent; - TfLiteTensor* bw_input_to_input_weights = + const TfLiteTensor* bw_input_to_input_weights = GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor); const bool bw_use_cifg = (bw_input_to_input_weights == nullptr); TfLiteIntArray* bw_scratch_buffer_size = TfLiteIntArrayCreate(2); @@ -463,48 +464,49 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); // Input tensor. - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); const int max_time = input->dims->data[0]; const int n_batch = input->dims->data[1]; const int n_input = input->dims->data[2]; // Tensors for the forward cell. 
- TfLiteTensor* fw_input_to_input_weights = + const TfLiteTensor* fw_input_to_input_weights = GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor); - TfLiteTensor* fw_input_to_forget_weights = + const TfLiteTensor* fw_input_to_forget_weights = GetInput(context, node, kFwInputToForgetWeightsTensor); - TfLiteTensor* fw_input_to_cell_weights = + const TfLiteTensor* fw_input_to_cell_weights = GetInput(context, node, kFwInputToCellWeightsTensor); - TfLiteTensor* fw_input_to_output_weights = + const TfLiteTensor* fw_input_to_output_weights = GetInput(context, node, kFwInputToOutputWeightsTensor); - TfLiteTensor* fw_recurrent_to_input_weights = + const TfLiteTensor* fw_recurrent_to_input_weights = GetOptionalInputTensor(context, node, kFwRecurrentToInputWeightsTensor); - TfLiteTensor* fw_recurrent_to_forget_weights = + const TfLiteTensor* fw_recurrent_to_forget_weights = GetInput(context, node, kFwRecurrentToForgetWeightsTensor); - TfLiteTensor* fw_recurrent_to_cell_weights = + const TfLiteTensor* fw_recurrent_to_cell_weights = GetInput(context, node, kFwRecurrentToCellWeightsTensor); - TfLiteTensor* fw_recurrent_to_output_weights = + const TfLiteTensor* fw_recurrent_to_output_weights = GetInput(context, node, kFwRecurrentToOutputWeightsTensor); - TfLiteTensor* fw_cell_to_input_weights = + const TfLiteTensor* fw_cell_to_input_weights = GetOptionalInputTensor(context, node, kFwCellToInputWeightsTensor); - TfLiteTensor* fw_cell_to_forget_weights = + const TfLiteTensor* fw_cell_to_forget_weights = GetOptionalInputTensor(context, node, kFwCellToForgetWeightsTensor); - TfLiteTensor* fw_cell_to_output_weights = + const TfLiteTensor* fw_cell_to_output_weights = GetOptionalInputTensor(context, node, kFwCellToOutputWeightsTensor); - TfLiteTensor* fw_input_gate_bias = + const TfLiteTensor* fw_input_gate_bias = GetOptionalInputTensor(context, node, kFwInputGateBiasTensor); - TfLiteTensor* fw_forget_gate_bias = + const TfLiteTensor* fw_forget_gate_bias = 
GetInput(context, node, kFwForgetGateBiasTensor); - TfLiteTensor* fw_cell_bias = GetInput(context, node, kFwCellGateBiasTensor); - TfLiteTensor* fw_output_gate_bias = + const TfLiteTensor* fw_cell_bias = + GetInput(context, node, kFwCellGateBiasTensor); + const TfLiteTensor* fw_output_gate_bias = GetInput(context, node, kFwOutputGateBiasTensor); - TfLiteTensor* fw_projection_weights = + const TfLiteTensor* fw_projection_weights = GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor); - TfLiteTensor* fw_projection_bias = + const TfLiteTensor* fw_projection_bias = GetOptionalInputTensor(context, node, kFwProjectionBiasTensor); TfLiteTensor* fw_output_state = @@ -513,42 +515,43 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor); // Tensors for the backward cell. - TfLiteTensor* bw_input_to_input_weights = + const TfLiteTensor* bw_input_to_input_weights = GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor); - TfLiteTensor* bw_input_to_forget_weights = + const TfLiteTensor* bw_input_to_forget_weights = GetInput(context, node, kBwInputToForgetWeightsTensor); - TfLiteTensor* bw_input_to_cell_weights = + const TfLiteTensor* bw_input_to_cell_weights = GetInput(context, node, kBwInputToCellWeightsTensor); - TfLiteTensor* bw_input_to_output_weights = + const TfLiteTensor* bw_input_to_output_weights = GetInput(context, node, kBwInputToOutputWeightsTensor); - TfLiteTensor* bw_recurrent_to_input_weights = + const TfLiteTensor* bw_recurrent_to_input_weights = GetOptionalInputTensor(context, node, kBwRecurrentToInputWeightsTensor); - TfLiteTensor* bw_recurrent_to_forget_weights = + const TfLiteTensor* bw_recurrent_to_forget_weights = GetInput(context, node, kBwRecurrentToForgetWeightsTensor); - TfLiteTensor* bw_recurrent_to_cell_weights = + const TfLiteTensor* bw_recurrent_to_cell_weights = GetInput(context, node, kBwRecurrentToCellWeightsTensor); - TfLiteTensor* 
bw_recurrent_to_output_weights = + const TfLiteTensor* bw_recurrent_to_output_weights = GetInput(context, node, kBwRecurrentToOutputWeightsTensor); - TfLiteTensor* bw_cell_to_input_weights = + const TfLiteTensor* bw_cell_to_input_weights = GetOptionalInputTensor(context, node, kBwCellToInputWeightsTensor); - TfLiteTensor* bw_cell_to_forget_weights = + const TfLiteTensor* bw_cell_to_forget_weights = GetOptionalInputTensor(context, node, kBwCellToForgetWeightsTensor); - TfLiteTensor* bw_cell_to_output_weights = + const TfLiteTensor* bw_cell_to_output_weights = GetOptionalInputTensor(context, node, kBwCellToOutputWeightsTensor); - TfLiteTensor* bw_input_gate_bias = + const TfLiteTensor* bw_input_gate_bias = GetOptionalInputTensor(context, node, kBwInputGateBiasTensor); - TfLiteTensor* bw_forget_gate_bias = + const TfLiteTensor* bw_forget_gate_bias = GetInput(context, node, kBwForgetGateBiasTensor); - TfLiteTensor* bw_cell_bias = GetInput(context, node, kBwCellGateBiasTensor); - TfLiteTensor* bw_output_gate_bias = + const TfLiteTensor* bw_cell_bias = + GetInput(context, node, kBwCellGateBiasTensor); + const TfLiteTensor* bw_output_gate_bias = GetInput(context, node, kBwOutputGateBiasTensor); - TfLiteTensor* bw_projection_weights = + const TfLiteTensor* bw_projection_weights = GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor); - TfLiteTensor* bw_projection_bias = + const TfLiteTensor* bw_projection_bias = GetOptionalInputTensor(context, node, kBwProjectionBiasTensor); TfLiteTensor* bw_output_state = diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc index 17ef2c572ebbfa54ba6856f7eebbcd6fd9e63868..60770ca0aa8b85d9710d26beca3d4d603da5db2f 100644 --- a/tensorflow/contrib/lite/kernels/cast.cc +++ b/tensorflow/contrib/lite/kernels/cast.cc @@ -32,7 +32,7 @@ constexpr int kOutputTensor = 0; TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); 
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // TODO(ahentz): these two checks would make the new implementation @@ -69,6 +69,9 @@ TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out, case kTfLiteFloat32: copyCast(in, out->data.f, num_elements); break; + case kTfLiteBool: + copyCast(in, out->data.b, num_elements); + break; default: // Unsupported type. return kTfLiteError; @@ -77,7 +80,7 @@ TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out, } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); const int num_elements = NumElements(input); TF_LITE_ENSURE_EQ(context, num_elements, NumElements(output)); @@ -90,6 +93,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return copyToTensor(input->data.uint8, output, num_elements); case kTfLiteFloat32: return copyToTensor(input->data.f, output, num_elements); + case kTfLiteBool: + return copyToTensor(input->data.b, output, num_elements); default: // Unsupported type. 
return kTfLiteError; diff --git a/tensorflow/contrib/lite/kernels/cast_test.cc b/tensorflow/contrib/lite/kernels/cast_test.cc index 4e56482a371550b6275a6380e2beebe3cef958ff..53e20007378392467356ab29ecb8b217bb7a9e89 100644 --- a/tensorflow/contrib/lite/kernels/cast_test.cc +++ b/tensorflow/contrib/lite/kernels/cast_test.cc @@ -57,6 +57,22 @@ TEST(CastOpModel, CastFloatToInt) { ElementsAreArray({100, 20, 3, 0, 0, 1})); } +TEST(CastOpModel, CastFloatToBool) { + CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_BOOL, {3, 2}}); + m.PopulateTensor<float>(m.input(), {100.f, -1.0f, 0.f, 0.4f, 0.999f, 1.1f}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector<bool>(m.output()), + ElementsAreArray({true, true, false, true, true, true})); +} + +TEST(CastOpModel, CastBoolToFloat) { + CastOpModel m({TensorType_BOOL, {3, 2}}, {TensorType_FLOAT32, {3, 2}}); + m.PopulateTensor<bool>(m.input(), {true, true, false, true, false, true}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector<float>(m.output()), + ElementsAreArray({1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f})); +} + } // namespace } // namespace tflite int main(int argc, char** argv) { diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc index 2885ce032b4b6a1c63337773678678613f6427b6..f678f48fa5bbbcece6c5b87030d951783378d78f 100644 --- a/tensorflow/contrib/lite/kernels/comparisons.cc +++ b/tensorflow/contrib/lite/kernels/comparisons.cc @@ -23,6 +23,7 @@ namespace tflite { namespace ops { namespace builtin { namespace comparisons { +namespace { constexpr int kInputTensor1 = 0; constexpr int kInputTensor2 = 1; @@ -32,8 +33,8 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); - TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const 
TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // Don't support string and bool. @@ -67,9 +68,60 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { GetTensorData(input2), GetTensorDims(input2), \ GetTensorData(output), GetTensorDims(output)); +TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + bool requires_broadcast = !HaveSameShapes(input1, input2); + // TODO(renjieliu): Support quantized data. + switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, Equal, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, Equal, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, Equal, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type %d, requires float|int", + input1->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +// TODO(renjieliu): Refactor the logic to avoid duplications. +TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + bool requires_broadcast = !HaveSameShapes(input1, input2); + // TODO(renjieliu): Support quantized data. 
+ switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, NotEqual, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, NotEqual, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, NotEqual, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type %d, requires float|int", + input1->type); + return kTfLiteError; + } + return kTfLiteOk; +} + TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); - TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); // TODO(renjieliu): Support quantized data. @@ -85,15 +137,16 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { break; default: context->ReportError(context, - "Does not support type other than float|int"); + "Does not support type %d, requires float|int", + input1->type); return kTfLiteError; } return kTfLiteOk; } TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); - TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); // TODO(renjieliu): Support quantized data. 
@@ -109,15 +162,16 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) { break; default: context->ReportError(context, - "Does not support type other than float|int"); + "Does not support type %d, requires float|int", + input1->type); return kTfLiteError; } return kTfLiteOk; } TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); - TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); // TODO(renjieliu): Support quantized data. @@ -133,15 +187,16 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) { break; default: context->ReportError(context, - "Does not support type other than float|int"); + "Does not support type %d, requires float|int", + input1->type); return kTfLiteError; } return kTfLiteOk; } TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); - TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); // TODO(renjieliu): Support quantized data. 
@@ -157,14 +212,29 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) { break; default: context->ReportError(context, - "Does not support type other than float|int"); + "Does not support type %d, requires float|int", + input1->type); return kTfLiteError; } return kTfLiteOk; } +} // namespace } // namespace comparisons +TfLiteRegistration* Register_EQUAL() { + static TfLiteRegistration r = { + nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::EqualEval}; + return &r; +} + +TfLiteRegistration* Register_NOT_EQUAL() { + static TfLiteRegistration r = {nullptr, nullptr, + comparisons::ComparisonPrepare, + comparisons::NotEqualEval}; + return &r; +} + TfLiteRegistration* Register_GREATER() { static TfLiteRegistration r = {nullptr, nullptr, comparisons::ComparisonPrepare, diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc index 835d238d36d1757a27119ae24b3c07232e9d3dc0..bb02e1c812fdc40bf515f1f978e9e39b5a16a4ea 100644 --- a/tensorflow/contrib/lite/kernels/comparisons_test.cc +++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc @@ -21,18 +21,17 @@ limitations under the License. 
namespace tflite { namespace { -using ::testing::ElementsAreArray; +using ::testing::ElementsAre; -class GreaterOpModel : public SingleOpModel { +class ComparisonOpModel : public SingleOpModel { public: - GreaterOpModel(std::initializer_list<int> input1_shape, - std::initializer_list<int> input2_shape, - TensorType input_type) { + ComparisonOpModel(std::initializer_list<int> input1_shape, + std::initializer_list<int> input2_shape, + TensorType input_type, BuiltinOperator op) { input1_ = AddInput(input_type); input2_ = AddInput(input_type); output_ = AddOutput(TensorType_BOOL); - SetBuiltinOp(BuiltinOperator_GREATER, BuiltinOptions_GreaterOptions, - CreateGreaterOptions(builder_).Union()); + ConfigureBuiltinOp(op); BuildInterpreter({input1_shape, input2_shape}); } @@ -46,245 +45,313 @@ class GreaterOpModel : public SingleOpModel { int input1_; int input2_; int output_; + + void ConfigureBuiltinOp(BuiltinOperator op) { + switch (op) { + case BuiltinOperator_EQUAL: { + SetBuiltinOp(op, BuiltinOptions_EqualOptions, + CreateEqualOptions(builder_).Union()); + break; + } + case BuiltinOperator_NOT_EQUAL: { + SetBuiltinOp(op, BuiltinOptions_NotEqualOptions, + CreateNotEqualOptions(builder_).Union()); + break; + } + case BuiltinOperator_GREATER: { + SetBuiltinOp(op, BuiltinOptions_GreaterOptions, + CreateGreaterOptions(builder_).Union()); + break; + } + case BuiltinOperator_GREATER_EQUAL: { + SetBuiltinOp(op, BuiltinOptions_GreaterEqualOptions, + CreateGreaterEqualOptions(builder_).Union()); + break; + } + case BuiltinOperator_LESS: { + SetBuiltinOp(op, BuiltinOptions_LessOptions, + CreateLessOptions(builder_).Union()); + break; + } + case BuiltinOperator_LESS_EQUAL: { + SetBuiltinOp(op, BuiltinOptions_LessEqualOptions, + CreateLessEqualOptions(builder_).Union()); + break; + } + default: { FAIL() << "We shouldn't get here."; } + } + } }; -TEST(ComparisonsTest, GreaterFloat) { - GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); +TEST(ComparisonsTest, EqualFloat) { +
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_EQUAL); model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } -TEST(ComparisonsTest, GreaterInt) { - GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); +TEST(ComparisonsTest, EqualInt) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {1, 2, 7, 5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } -TEST(ComparisonsTest, GreaterBroadcast) { - GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); +TEST(ComparisonsTest, EqualBroadcast) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {7}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } -TEST(ComparisonsTest, GreaterBroadcastTwoD) { - GreaterOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); +TEST(ComparisonsTest, EqualBroadcastTwoD) { + 
ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); model.PopulateTensor(model.input2(), {7, 1, 2, 4}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false, - false, true, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, false, false, + false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } -class GreaterEqualOpModel : public SingleOpModel { - public: - GreaterEqualOpModel(std::initializer_list input1_shape, - std::initializer_list input2_shape, - TensorType input_type) { - input1_ = AddInput(input_type); - input2_ = AddInput(input_type); - output_ = AddOutput(TensorType_BOOL); - SetBuiltinOp(BuiltinOperator_GREATER_EQUAL, - BuiltinOptions_GreaterEqualOptions, - CreateGreaterEqualOptions(builder_).Union()); - BuildInterpreter({input1_shape, input2_shape}); - } +TEST(ComparisonsTest, NotEqualFloat) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); + model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); + model.Invoke(); - int input1() { return input1_; } - int input2() { return input2_; } + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} - std::vector GetOutput() { return ExtractVector(output_); } - std::vector GetOutputShape() { return GetTensorShape(output_); } +TEST(ComparisonsTest, NotEqualInt) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor(model.input2(), {1, 2, 7, 5}); + model.Invoke(); - private: - int input1_; - int input2_; - int 
output_; -}; + EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, NotEqualBroadcast) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor(model.input2(), {7}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, NotEqualBroadcastTwoD) { + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); + model.PopulateTensor(model.input2(), {7, 1, 2, 4}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, true, true, true, true, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); +} + +TEST(ComparisonsTest, GreaterFloat) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_GREATER); + model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); + model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, GreaterInt) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_GREATER); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor(model.input2(), {1, 2, 7, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, GreaterBroadcast) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_GREATER); + 
model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor(model.input2(), {7}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, GreaterBroadcastTwoD) { + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_GREATER); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); + model.PopulateTensor(model.input2(), {7, 1, 2, 4}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, true, false, false, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); +} TEST(ComparisonsTest, GreaterEqualFloat) { - GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_GREATER_EQUAL); model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, true, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, GreaterEqualInt) { - GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_GREATER_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {1, 2, 7, 5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); 
} TEST(ComparisonsTest, GreaterEqualBroadcast) { - GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_GREATER_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {7}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, GreaterEqualBroadcastTwoD) { - GreaterEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_GREATER_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); model.PopulateTensor(model.input2(), {7, 1, 2, 4}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false, - false, true, true, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, true, false, false, true, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } -class LessOpModel : public SingleOpModel { - public: - LessOpModel(std::initializer_list input1_shape, - std::initializer_list input2_shape, TensorType input_type) { - input1_ = AddInput(input_type); - input2_ = AddInput(input_type); - output_ = AddOutput(TensorType_BOOL); - SetBuiltinOp(BuiltinOperator_LESS, BuiltinOptions_LessOptions, - CreateLessOptions(builder_).Union()); - BuildInterpreter({input1_shape, input2_shape}); - } - - int input1() { return input1_; } - int input2() { return input2_; } - - std::vector GetOutput() { return ExtractVector(output_); } - std::vector GetOutputShape() { return GetTensorShape(output_); } - - private: - int 
input1_; - int input2_; - int output_; -}; TEST(ComparisonsTest, LessFloat) { - LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_LESS); model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, false, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessInt) { - LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_LESS); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {1, 2, 6, 5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessBroadcast) { - LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_LESS); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {7}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessBroadcastTwoD) { - LessOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + 
ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_LESS); model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 6, 8}); model.PopulateTensor(model.input2(), {7, 1, 2, 4}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true, - true, false, false, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, false, true, true, false, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } -class LessEqualOpModel : public SingleOpModel { - public: - LessEqualOpModel(std::initializer_list input1_shape, - std::initializer_list input2_shape, - TensorType input_type) { - input1_ = AddInput(input_type); - input2_ = AddInput(input_type); - output_ = AddOutput(TensorType_BOOL); - SetBuiltinOp(BuiltinOperator_LESS_EQUAL, BuiltinOptions_LessEqualOptions, - CreateLessEqualOptions(builder_).Union()); - BuildInterpreter({input1_shape, input2_shape}); - } - - int input1() { return input1_; } - int input2() { return input2_; } - - std::vector GetOutput() { return ExtractVector(output_); } - std::vector GetOutputShape() { return GetTensorShape(output_); } - - private: - int input1_; - int input2_; - int output_; -}; - TEST(ComparisonsTest, LessEqualFloat) { - LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_LESS_EQUAL); model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessEqualInt) { - 
LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_LESS_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {1, 2, 7, 5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessEqualBroadcast) { - LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_LESS_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {7}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessEqualBroadcastTwoD) { - LessEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_LESS_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); model.PopulateTensor(model.input2(), {7, 1, 2, 4}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true, - true, false, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, false, true, true, false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/conv.cc 
b/tensorflow/contrib/lite/kernels/conv.cc index 3b467b3aa284586ab8e67ede55583adffbe06cc7..747c8a62c08d3c4be4c180461727ee0b086ffd47 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -134,7 +134,9 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context, // optimized_ops.h, in order to avoid a DCHECK(!im2col_data). data->need_im2col = (params->stride_width != 1 || params->stride_height != 1 || - filter_width != 1 || filter_height != 1); + params->dilation_width_factor != 1 || + params->dilation_height_factor != 1 || filter_width != 1 || + filter_height != 1); // If we're using the optimized multithreaded EigenTensor implementation of // convolution, it expects the filter weights to be transposed compared to // the normal TF Lite buffer format. Typical TF Lite weights are @@ -212,8 +214,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } else { TF_LITE_ENSURE_EQ(context, bias->type, data_type); } - TF_LITE_ENSURE_EQ(context, bias->dims->size, 1); - TF_LITE_ENSURE_EQ(context, bias->dims->data[0], filter->dims->data[0]); + TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0)); } int channels_out = filter->dims->data[0]; @@ -255,6 +256,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { double real_multiplier = 0.0; TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, input, filter, bias, output, &real_multiplier)); + TF_LITE_ENSURE(context, real_multiplier < 1.0); QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, &data->output_shift); CalculateActivationRangeUint8(params->activation, output, @@ -489,7 +491,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { bias, im2col, hwcn_weights, output); break; default: - context->ReportError(context, "Type not currently supported."); + context->ReportError(context, "Type %d not currently supported.", + input->type); return kTfLiteError; } return 
kTfLiteOk; diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc index eeda1bc3c5ba2da5b6546ce36925a6f20fc9cbae..a308de055f49eddba99d02e264fad11409a799f4 100644 --- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc +++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc @@ -83,9 +83,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { bool hasBias = NumInputs(node) == 3; TF_LITE_ENSURE(context, hasBias || NumInputs(node) == 2); - TfLiteTensor* input = GetInput(context, node, kInputTensor); - TfLiteTensor* filter = GetInput(context, node, kFilterTensor); - TfLiteTensor* bias = nullptr; + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* bias = nullptr; TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); @@ -151,8 +151,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { double real_multiplier = 0.0; TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, input, filter, bias, output, &real_multiplier)); - QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, - &data->output_shift); + int exponent; + QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); + data->output_shift = -exponent; CalculateActivationRangeUint8(params->activation, output, &data->output_activation_min, &data->output_activation_max); @@ -169,8 +170,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { template void EvalFloat(TfLiteContext* context, TfLiteNode* node, TfLiteDepthwiseConvParams* params, OpData* data, - TfLiteTensor* input, TfLiteTensor* filter, TfLiteTensor* bias, - TfLiteTensor* output) { + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { float output_activation_min, output_activation_max; 
CalculateActivationRangeFloat(params->activation, &output_activation_min, &output_activation_max); @@ -196,8 +197,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, template void EvalQuantized(TfLiteContext* context, TfLiteNode* node, TfLiteDepthwiseConvParams* params, OpData* data, - TfLiteTensor* input, TfLiteTensor* filter, - TfLiteTensor* bias, TfLiteTensor* output) { + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { auto input_offset = -input->params.zero_point; auto filter_offset = -filter->params.zero_point; auto output_offset = output->params.zero_point; @@ -230,9 +231,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { OpData* data = reinterpret_cast(node->user_data); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TfLiteTensor* input = GetInput(context, node, kInputTensor); - TfLiteTensor* filter = GetInput(context, node, kFilterTensor); - TfLiteTensor* bias = + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* bias = (NumInputs(node) == 3) ? 
GetInput(context, node, kBiasTensor) : nullptr; // TODO(aselle): Consider whether float conv and quantized conv should be @@ -247,7 +248,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { bias, output); break; default: - context->ReportError(context, "Type not currently supported."); + context->ReportError(context, "Type %d not currently supported.", + input->type); return kTfLiteError; } return kTfLiteOk; diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc index 1439c8bce14ad127ed68dc54991aed8b8bb39383..c00cafb9fbfaf53d4dc301ccd3f21a6c6fd892e6 100644 --- a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc +++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc @@ -47,12 +47,6 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel { } output_ = AddOutput(output); - if (input.type != TensorType_FLOAT32) { - // The following is required by quantized inference. It is the unittest's - // responsibility to make sure the output scale falls into the correct - // range. 
- CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); - } int input_depth = GetShape(input_)[3]; int output_depth = GetShape(filter_)[3]; @@ -176,6 +170,43 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) { })); } +TEST(QuantizedDepthwiseConvolutionOpTest, + SimpleTestQuantizedFilterMultiplierGreaterThan1) { + QuantizedDepthwiseConvolutionOpModel quant_op( + {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64}, + {TensorType_UINT8, {1, 2, 2, 4}, -128.5, 128}, + {TensorType_UINT8, {}, -127, 128}); + DepthwiseConvolutionOpModel float_op({TensorType_FLOAT32, {1, 3, 2, 2}}, + {TensorType_FLOAT32, {1, 2, 2, 4}}, + {TensorType_FLOAT32, {}}); + + std::initializer_list input = { + 1, 2, 7, 8, // column 1 + 3, 4, 9, 10, // column 2 + 5, 6, 11, 12, // column 3 + }; + std::initializer_list filter = { + 1, 2, 3, 4, // + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }; + std::initializer_list bias = {1, 2, 3, 4}; + + quant_op.SetInput(input); + quant_op.SetFilter(filter); + quant_op.SetBias(bias); + quant_op.Invoke(); + + float_op.SetInput(input); + float_op.SetFilter(filter); + float_op.SetBias(bias); + float_op.Invoke(); + + EXPECT_THAT(quant_op.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/dequantize.cc b/tensorflow/contrib/lite/kernels/dequantize.cc index e685f2465f627cf30e02564e6f16e1ec69e208e2..672b2170e4990f0a7ca9755071d9d086f5ae5c2b 100644 --- a/tensorflow/contrib/lite/kernels/dequantize.cc +++ b/tensorflow/contrib/lite/kernels/dequantize.cc @@ -32,7 +32,7 @@ struct OpContext { input = GetInput(context, node, 0); output = GetOutput(context, node, 0); } - TfLiteTensor* input; + const TfLiteTensor* input; TfLiteTensor* output; }; diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc index ec380c8e4956e5bcd0d7559bfd8f89a52d9d233c..d264821e30cf622ff5d3d8ad513add46caa9e7ae 
100644 --- a/tensorflow/contrib/lite/kernels/div.cc +++ b/tensorflow/contrib/lite/kernels/div.cc @@ -57,8 +57,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); - TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TF_LITE_ENSURE_EQ(context, input1->type, input2->type); @@ -80,7 +80,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { template void EvalFloat(TfLiteContext* context, TfLiteNode* node, TfLiteDivParams* params, const OpData* data, - TfLiteTensor* input1, TfLiteTensor* input2, + const TfLiteTensor* input1, const TfLiteTensor* input2, TfLiteTensor* output) { float output_activation_min, output_activation_max; CalculateActivationRangeFloat(params->activation, &output_activation_min, @@ -106,22 +106,21 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, #undef TF_LITE_DIV } - - template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); OpData* data = reinterpret_cast(node->user_data); - TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); - TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); if (output->type == kTfLiteFloat32) { EvalFloat(context, node, params, data, input1, input2, output); } else { - context->ReportError(context, - "Div only supports FLOAT32 and quantized UINT8 now."); + context->ReportError( + context, "Div only supports FLOAT32 and 
quantized UINT8 now, got %d.", + output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc new file mode 100644 index 0000000000000000000000000000000000000000..98c21ce9d390aaa1f3cb5fdb8f31cbffb1b81d6a --- /dev/null +++ b/tensorflow/contrib/lite/kernels/elementwise.cc @@ -0,0 +1,83 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <cmath> +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace elementwise { + +TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + const TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + // Quantized float is not supported yet.
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + return context->ResizeTensor(context, output, + TfLiteIntArrayCopy(input->dims)); +} + +inline TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, + float float_func(float)) { + const TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + size_t elements = NumElements(input); + const float* in = GetTensorData<float>(input); + const float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; in++, out++) *out = float_func(*in); + return kTfLiteOk; + } + default: { + context->ReportError(context, "Input type is %d, requires float32", + input->type); + return kTfLiteError; + } + } +} + +TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, std::sin); +} + +TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, std::log); +} + +} // namespace elementwise + +TfLiteRegistration* Register_SIN() { + static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, + elementwise::SinEval}; + return &r; +} + +TfLiteRegistration* Register_LOG() { + static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, + elementwise::LogEval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..10e88d5a31868eeb5f65c7ade1f1c73827dea24a --- /dev/null +++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class ElementWiseOpModel : public SingleOpModel { + public: + ElementWiseOpModel(BuiltinOperator op, + std::initializer_list<int> input_shape) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(op, BuiltinOptions_NONE, 0); + BuildInterpreter({input_shape}); + } + + int input() const { return input_; } + int output() const { return output_; } + + private: + int input_; + int output_; +}; + +TEST(ElementWise, Sin) { + ElementWiseOpModel m(BuiltinOperator_SIN, {1, 1, 4, 1}); + m.PopulateTensor<float>(m.input(), {0, 3.1415926, -3.1415926, 1}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector<float>(m.output()), + ElementsAreArray(ArrayFloatNear({0, 0, 0, 0.84147}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + +TEST(ElementWise, Log) { + ElementWiseOpModel m(BuiltinOperator_LOG, {1, 1, 4, 1}); + m.PopulateTensor<float>(m.input(), {1, 3.1415926, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector<float>(m.output()), + ElementsAreArray(ArrayFloatNear({0, 1.14473, 0, 0}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); +
::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc index 4e8cb396d43a58f94b08eb8dd8b05d16fd74fd2f..9410bead5e7a68363d034c22fb2c0eff9f060ef1 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc @@ -24,7 +24,8 @@ limitations under the License. // Output: // Output.dim[0] == Tensor[0].dim[0], num of lookups // Output.dim[1] == Tensor[1].dim[1], num of items per row -// Each item in output is a raw bytes copy of corresponding item in input. +// Each item in output is a raw bytes copy of the corresponding item in input, +// or a dequantized value in the case of a uint8 input. // When indices are out of bound, the ops will not succeed. // @@ -51,11 +52,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* lookup = GetInput(context, node, 0); + const TfLiteTensor* lookup = GetInput(context, node, 0); TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1); TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); - TfLiteTensor* value = GetInput(context, node, 1); + const TfLiteTensor* value = GetInput(context, node, 1); TF_LITE_ENSURE(context, NumDimensions(value) >= 2); TfLiteTensor* output = GetOutput(context, node, 0); @@ -69,11 +70,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return context->ResizeTensor(context, output, outputSize); } -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* output = GetOutput(context, node, 0); - TfLiteTensor* lookup = GetInput(context, node, 0); - TfLiteTensor* value = GetInput(context, node, 1); - +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + const TfLiteTensor* lookup, const TfLiteTensor* value, + TfLiteTensor* output) { const int 
row_size = SizeOfDimension(value, 0); const int row_bytes = value->bytes / row_size; @@ -91,6 +90,52 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, + const TfLiteTensor* lookup, const TfLiteTensor* value, + TfLiteTensor* output) { + const int row_size = SizeOfDimension(value, 0); + const double scaling_factor = 1.0 / value->params.scale; + + // col_size after we flatten tensor into 2D. + int col_size = 1; + for (int i = 1; i < NumDimensions(value); i++) { + col_size *= SizeOfDimension(value, i); + } + + for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { + int idx = lookup->data.i32[i]; + if (idx >= row_size || idx < 0) { + context->ReportError(context, "Embedding Lookup: index out of bounds."); + return kTfLiteError; + } else { + // Dequantize embedding values. + // TODO(alanchiao): refactor scalar multiply into separate function + // for ease of adding a neon equivalent if ever necessary. 
+ for (int j = 0; j < col_size; j++) { + output->data.f[j + i * col_size] = + value->data.uint8[j + idx * col_size] * scaling_factor; + } + } + } + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* lookup = GetInput(context, node, 0); + const TfLiteTensor* value = GetInput(context, node, 1); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (value->type) { + case kTfLiteFloat32: + return EvalFloat(context, node, lookup, value, output); + case kTfLiteUInt8: + return EvalHybrid(context, node, lookup, value, output); + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } +} + } // namespace embedding_lookup TfLiteRegistration* Register_EMBEDDING_LOOKUP() { diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc index 6c770e7f71efe83eace3640c47e03e0c7ab19e20..d3be36993c3843cf928bf458ae2c8019df7ccf31 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc @@ -81,19 +81,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 5); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* ids = GetInput(context, node, 0); + const TfLiteTensor* ids = GetInput(context, node, 0); TF_LITE_ENSURE_EQ(context, NumDimensions(ids), 1); TF_LITE_ENSURE_EQ(context, ids->type, kTfLiteInt32); - TfLiteTensor* indices = GetInput(context, node, 1); + const TfLiteTensor* indices = GetInput(context, node, 1); TF_LITE_ENSURE_EQ(context, NumDimensions(indices), 2); TF_LITE_ENSURE_EQ(context, indices->type, kTfLiteInt32); - TfLiteTensor* shape = GetInput(context, node, 2); + const TfLiteTensor* shape = GetInput(context, node, 2); TF_LITE_ENSURE_EQ(context, NumDimensions(shape), 1); TF_LITE_ENSURE_EQ(context, shape->type, kTfLiteInt32); - TfLiteTensor* 
weights = GetInput(context, node, 3); + const TfLiteTensor* weights = GetInput(context, node, 3); TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 1); TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteFloat32); @@ -102,7 +102,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0), SizeOfDimension(weights, 0)); - TfLiteTensor* value = GetInput(context, node, 4); + const TfLiteTensor* value = GetInput(context, node, 4); TF_LITE_ENSURE(context, NumDimensions(value) >= 2); // Mark the output as a dynamic tensor. @@ -139,11 +139,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); TfLiteTensor* output = GetOutput(context, node, 0); - TfLiteTensor* ids = GetInput(context, node, 0); - TfLiteTensor* indices = GetInput(context, node, 1); - TfLiteTensor* dense_shape = GetInput(context, node, 2); - TfLiteTensor* weights = GetInput(context, node, 3); - TfLiteTensor* value = GetInput(context, node, 4); + const TfLiteTensor* ids = GetInput(context, node, 0); + const TfLiteTensor* indices = GetInput(context, node, 1); + const TfLiteTensor* dense_shape = GetInput(context, node, 2); + const TfLiteTensor* weights = GetInput(context, node, 3); + const TfLiteTensor* value = GetInput(context, node, 4); const int lookup_rank = SizeOfDimension(indices, 1); const int embedding_rank = NumDimensions(value); diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc index 9b501878f196216a61568bfa36e6615f4dd07478..04657fd86323ef1c58d069c06097c7665f55cc87 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc @@ -7,13 +7,14 @@ You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software -distributed under the License is 
distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. +distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License +for the specific language governing permissions and limitations under the +License. ==============================================================================*/ // Unit test for TFLite Lookup op. +#include #include #include @@ -29,12 +30,13 @@ namespace { using ::testing::ElementsAreArray; -class EmbeddingLookupOpModel : public SingleOpModel { +class BaseEmbeddingLookupOpModel : public SingleOpModel { public: - EmbeddingLookupOpModel(std::initializer_list index_shape, - std::initializer_list weight_shape) { + BaseEmbeddingLookupOpModel(std::initializer_list index_shape, + std::initializer_list weight_shape, + TensorType weight_type = TensorType_FLOAT32) { input_ = AddInput(TensorType_INT32); - weight_ = AddInput(TensorType_FLOAT32); + weight_ = AddInput(weight_type); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0); BuildInterpreter({index_shape, weight_shape}); @@ -44,6 +46,18 @@ class EmbeddingLookupOpModel : public SingleOpModel { PopulateTensor(input_, data); } + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int weight_; + int output_; +}; + +class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel { + public: + using BaseEmbeddingLookupOpModel::BaseEmbeddingLookupOpModel; + void Set3DWeightMatrix(const std::function& function) { TfLiteTensor* tensor = interpreter_->tensor(weight_); int rows = tensor->dims->data[0]; @@ -57,20 +71,25 @@ class EmbeddingLookupOpModel : public SingleOpModel { } } } +}; - std::vector GetOutput() { return ExtractVector(output_); } +class 
HybridEmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel { + public: + HybridEmbeddingLookupOpModel(std::initializer_list index_shape, + std::initializer_list weight_shape) + : BaseEmbeddingLookupOpModel(index_shape, weight_shape, + TensorType_UINT8) {} - private: - int input_; - int weight_; - int output_; + void SetWeight(std::initializer_list data) { + SymmetricQuantizeAndPopulate(weight_, data); + } }; // TODO(ahentz): write more tests that exercise the details of the op, such as // lookup errors and variable input shapes. TEST(EmbeddingLookupOpTest, SimpleTest) { EmbeddingLookupOpModel m({3}, {3, 2, 4}); - m.PopulateTensor(0, {1, 0, 2}); + m.SetInput({1, 0, 2}); m.Set3DWeightMatrix( [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); @@ -84,6 +103,69 @@ TEST(EmbeddingLookupOpTest, SimpleTest) { }))); } +TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) { + HybridEmbeddingLookupOpModel m({3}, {3, 8}); + m.SetInput({1, 0, 2}); + m.SetWeight({ + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + { + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }, + 7.41e-03))); +} + +TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) { + HybridEmbeddingLookupOpModel m({3}, {3, 2, 4}); + m.SetInput({1, 0, 2}); + m.SetWeight({ + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + { + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 
2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }, + 7.41e-03))); +} + +TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) { + HybridEmbeddingLookupOpModel m({3}, {3, 2, 2, 2}); + m.SetInput({1, 0, 2}); + m.SetWeight({ + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + { + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }, + 7.41e-03))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/exp.cc b/tensorflow/contrib/lite/kernels/exp.cc index a9e79b742dc2c80ce4ed9a3aa786814265dcb660..ce03cdfe26cac861837c4d534a083787e149cef0 100644 --- a/tensorflow/contrib/lite/kernels/exp.cc +++ b/tensorflow/contrib/lite/kernels/exp.cc @@ -36,7 +36,7 @@ struct ExpContext { input = GetInput(context, node, 0); output = GetOutput(context, node, 0); } - TfLiteTensor* input; + const TfLiteTensor* input; TfLiteTensor* output; }; diff --git a/tensorflow/contrib/lite/kernels/expand_dims.cc b/tensorflow/contrib/lite/kernels/expand_dims.cc new file mode 100644 index 0000000000000000000000000000000000000000..ed33012864354cd93eac2344f75d7eca302c8952 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/expand_dims.cc @@ -0,0 +1,113 @@ + +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +namespace tflite { +namespace ops { +namespace builtin { +namespace expand_dims { +constexpr int kInput = 0; +constexpr int kAxis = 1; +constexpr int kOutput = 0; + +namespace { +TfLiteStatus ExpandTensorDim(TfLiteContext* context, const TfLiteTensor& input, + int axis, TfLiteTensor* output) { + const TfLiteIntArray& input_dims = *input.dims; + if (axis < 0) { + axis = input_dims.size + 1 + axis; + } + TF_LITE_ENSURE(context, axis <= input_dims.size); + + TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_dims.size + 1); + for (int i = 0; i < output_dims->size; ++i) { + if (i < axis) { + output_dims->data[i] = input_dims.data[i]; + } else if (i == axis) { + output_dims->data[i] = 1; + } else { + output_dims->data[i] = input_dims.data[i - 1]; + } + } + + return context->ResizeTensor(context, output, output_dims); +} + +TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context, + const TfLiteTensor& axis, int* axis_value) { + TF_LITE_ENSURE_EQ(context, NumElements(&axis), 1); + switch (axis.type) { + case kTfLiteInt32: + *axis_value = *GetTensorData(&axis); + return kTfLiteOk; + case kTfLiteInt64: + *axis_value = 
*GetTensorData(&axis); + return kTfLiteOk; + default: + return kTfLiteError; + } +} + +} // namespace + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + const TfLiteTensor* input = GetInput(context, node, kInput); + const TfLiteTensor* axis = GetInput(context, node, kAxis); + TfLiteTensor* output = GetOutput(context, node, 0); + output->type = input->type; + if (IsConstantTensor(axis)) { + int axis_value; + TF_LITE_ENSURE_OK(context, + GetAxisValueFromTensor(context, *axis, &axis_value)); + return ExpandTensorDim(context, *input, axis_value, output); + } + SetTensorToDynamic(output); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + // Just copy input to output. + const TfLiteTensor* input = GetInput(context, node, kInput); + TfLiteTensor* output = GetOutput(context, node, 0); + const TfLiteTensor* axis = GetInput(context, node, kAxis); + if (IsDynamicTensor(output)) { + int axis_value; + TF_LITE_ENSURE_OK(context, + GetAxisValueFromTensor(context, *axis, &axis_value)); + TF_LITE_ENSURE_OK(context, + ExpandTensorDim(context, *input, axis_value, output)); + } + memcpy(output->data.raw, input->data.raw, input->bytes); + return kTfLiteOk; +} + +} // namespace expand_dims +TfLiteRegistration* Register_EXPAND_DIMS() { + static TfLiteRegistration r = {nullptr, nullptr, expand_dims::Prepare, + expand_dims::Eval}; + return &r; +} +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/expand_dims_test.cc b/tensorflow/contrib/lite/kernels/expand_dims_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b755e8ce293442813b26ec3177162a3c95af2f89 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/expand_dims_test.cc @@ -0,0 +1,83 @@ + +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class ExpandDimsOpModel : public SingleOpModel { + public: + ExpandDimsOpModel(std::initializer_list input_shape, + TensorType input_type) { + input_ = AddInput(input_type); + axis_ = AddInput(TensorType_INT32); + output_ = AddOutput(input_type); + SetBuiltinOp(BuiltinOperator_EXPAND_DIMS, BuiltinOptions_ExpandDimsOptions, + 0); + BuildInterpreter({input_shape, {1}}); + } + void SetInputFloat(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetAxis(int axis) { PopulateTensor(axis_, {axis}); } + std::vector GetValuesFloat() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int axis_; + int output_; +}; + +TEST(ExpandDimsOpTest, DifferentAxis) { + ExpandDimsOpModel m({2, 2}, TensorType_FLOAT32); + const auto values = {-1.f, 1.f, -2.f, 2.f}; + m.SetInputFloat(values); + m.SetAxis(0); + m.Invoke(); + EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values)); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 2})); 
+ + m.SetAxis(1); + m.Invoke(); + EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values)); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2})); + + m.SetAxis(2); + m.Invoke(); + EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values)); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1})); + + m.SetAxis(-1); + m.Invoke(); + EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values)); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1})); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/floor.cc b/tensorflow/contrib/lite/kernels/floor.cc index 4b4395f711614a3b7047dc8f144ca3fa36b8a89b..697b777693e275e36d56f7865c8a3638071591a0 100644 --- a/tensorflow/contrib/lite/kernels/floor.cc +++ b/tensorflow/contrib/lite/kernels/floor.cc @@ -27,7 +27,7 @@ constexpr int kInputTensor = 0; constexpr int kOutputTensor = 0; TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -38,7 +38,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); optimized_ops::Floor(GetTensorData(input), GetTensorDims(input), diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc index 470b52b7bc4e65596ba1a3f4f8c036819dcaad28..5a0524bec6a8c8970a4f7300a21e77da84a34822 100644 
--- a/tensorflow/contrib/lite/kernels/fully_connected.cc +++ b/tensorflow/contrib/lite/kernels/fully_connected.cc @@ -89,9 +89,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, node->inputs->size, 3); TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); - TfLiteTensor* input = GetInput(context, node, kInputTensor); - TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); - TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); + const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // Check all the parameters of tensor match within themselves and match the @@ -101,17 +101,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { input_size *= input->dims->data[i]; } + TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2); const int batch_size = input_size / filter->dims->data[1]; const int num_units = filter->dims->data[0]; - TF_LITE_ASSERT_EQ(input_size, batch_size * filter->dims->data[1]); + TF_LITE_ENSURE_EQ(context, input_size, batch_size * filter->dims->data[1]); if (bias) { - TF_LITE_ASSERT_EQ(bias->dims->data[0], num_units); + TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0)); } - TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2); - TF_LITE_ENSURE_EQ(context, NumDimensions(bias), 1); - // Note that quantized inference requires that all tensors have their // parameters set. This is usually done during quantized training. 
TfLiteType data_type = input->type; @@ -119,6 +117,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { double real_multiplier = 0.0; TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, input, filter, bias, output, &real_multiplier)); + TF_LITE_ENSURE(context, real_multiplier < 1.0); QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, &data->output_shift); CalculateActivationRangeUint8(params->activation, output, @@ -158,8 +157,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node, TfLiteFullyConnectedParams* params, OpData* data, - TfLiteTensor* input, TfLiteTensor* filter, - TfLiteTensor* bias, TfLiteTensor* output) { + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { int total_input_size = 1; for (int i = 0; i < input->dims->size; i++) { total_input_size *= input->dims->data[i]; @@ -191,8 +190,10 @@ TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node, TfLiteStatus EvalPieQuantized(TfLiteContext* context, TfLiteNode* node, TfLiteFullyConnectedParams* params, OpData* data, - TfLiteTensor* input, TfLiteTensor* filter, - TfLiteTensor* bias, TfLiteTensor* input_quantized, + const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, + TfLiteTensor* input_quantized, TfLiteTensor* output) { // Check the types for this hybrid Op. TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); @@ -217,11 +218,8 @@ TfLiteStatus EvalPieQuantized(TfLiteContext* context, TfLiteNode* node, tensor_utils::ZeroVector(output->data.f, batch_size * num_units); } - // TODO(mirkov): change std::minmax_element with a vectorized call. - auto minmax_element = - std::minmax_element(input->data.f, input->data.f + total_input_size); // Save matrix multiplication computation for all zero input. 
- if (*minmax_element.first == 0.0 && *minmax_element.second == 0.0) { + if (tensor_utils::IsZeroVector(input->data.f, total_input_size)) { tensor_utils::ApplyActivationToVector(output->data.f, batch_size * num_units, params->activation, output->data.f); @@ -271,8 +269,9 @@ TfLiteStatus EvalPieQuantized(TfLiteContext* context, TfLiteNode* node, template TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, TfLiteFullyConnectedParams* params, OpData* data, - TfLiteTensor* input, TfLiteTensor* filter, - TfLiteTensor* bias, TfLiteTensor* output) { + const TfLiteTensor* input, + const TfLiteTensor* filter, const TfLiteTensor* bias, + TfLiteTensor* output) { gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); int32_t input_offset = -input->params.zero_point; @@ -311,8 +310,8 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, template TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, TfLiteFullyConnectedParams* params, OpData* data, - TfLiteTensor* input, TfLiteTensor* filter, - TfLiteTensor* bias, TfLiteTensor* output) { + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { float output_activation_min, output_activation_max; CalculateActivationRangeFloat(params->activation, &output_activation_min, &output_activation_max); @@ -342,9 +341,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { reinterpret_cast(node->builtin_data); OpData* data = reinterpret_cast(node->user_data); - TfLiteTensor* input = GetInput(context, node, kInputTensor); - TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); - TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); + const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); TfLiteTensor* output = 
GetOutput(context, node, kOutputTensor); switch (filter->type) { // Already know in/out types are same. @@ -355,7 +354,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return EvalQuantized(context, node, params, data, input, filter, bias, output); default: - context->ReportError(context, "Type not currently supported."); + context->ReportError(context, "Type %d not currently supported.", + filter->type); return kTfLiteError; } return kTfLiteOk; diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc index 0e4187d1eac64636a2e2b25e9a1cc45c3a4da557..6a2341461f2c627c78bd4783ee27579b59b5fde3 100644 --- a/tensorflow/contrib/lite/kernels/gather.cc +++ b/tensorflow/contrib/lite/kernels/gather.cc @@ -35,8 +35,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const auto* params = reinterpret_cast(node->builtin_data); - TfLiteTensor* input = GetInput(context, node, kInputTensor); - TfLiteTensor* positions = GetInput(context, node, kInputPositions); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* positions = GetInput(context, node, kInputPositions); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // Only INT32 positions are supported. 
TF_LITE_ENSURE_EQ(context, positions->type, kTfLiteInt32); @@ -59,8 +59,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1); } break; default: - context->ReportError(context, - "Only float32 and string types are supported"); + context->ReportError( + context, "Only float32 and string types are supported, got %d", + input->type); return kTfLiteError; } const int num_dimensions = @@ -81,8 +82,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* input = GetInput(context, node, kInputTensor); - TfLiteTensor* positions = GetInput(context, node, kInputPositions); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* positions = GetInput(context, node, kInputPositions); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); const int input_rank = NumDimensions(input); #define TF_LITE_GATHER(data_type, index_type) \ diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc index 3b82601d119b2e4946db6e3577300168c7e710b6..41211d41aa85a5a2da6ae96dc6f0337c54fb1a45 100644 --- a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc +++ b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc @@ -60,15 +60,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2); - TfLiteTensor* lookup = GetInput(context, node, 0); + const TfLiteTensor* lookup = GetInput(context, node, 0); TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1); TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); - TfLiteTensor* key = GetInput(context, node, 1); + const TfLiteTensor* key = GetInput(context, node, 1); TF_LITE_ENSURE_EQ(context, NumDimensions(key), 1); TF_LITE_ENSURE_EQ(context, key->type, kTfLiteInt32); - TfLiteTensor* value 
= GetInput(context, node, 2); + const TfLiteTensor* value = GetInput(context, node, 2); TF_LITE_ENSURE(context, NumDimensions(value) >= 1); TF_LITE_ENSURE_EQ(context, SizeOfDimension(key, 0), SizeOfDimension(value, 0)); @@ -102,9 +102,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, 0); TfLiteTensor* hits = GetOutput(context, node, 1); - TfLiteTensor* lookup = GetInput(context, node, 0); - TfLiteTensor* key = GetInput(context, node, 1); - TfLiteTensor* value = GetInput(context, node, 2); + const TfLiteTensor* lookup = GetInput(context, node, 0); + const TfLiteTensor* key = GetInput(context, node, 1); + const TfLiteTensor* value = GetInput(context, node, 2); const int num_rows = SizeOfDimension(value, 0); const int row_bytes = value->bytes / num_rows; diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index d8340d426ae0bda1dbecc9322650f7c75985126b..75298b995d6184985efc76c60c2f5541e9cbea40 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -302,6 +302,8 @@ cc_library( name = "neon_tensor_utils", srcs = [ "optimized/neon_tensor_utils.cc", + "reference/portable_tensor_utils.cc", + "reference/portable_tensor_utils.h", ], hdrs = [ "common.h", @@ -313,11 +315,11 @@ cc_library( copts = NEON_FLAGS_IF_APPLICABLE + HARD_FP_FLAGS_IF_APPLICABLE, deps = [ ":cpu_check", - ":portable_tensor_utils", ":round", ":types", "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite/kernels:activation_functor", + "//tensorflow/contrib/lite/kernels:op_macros", "@arm_neon_2_x86_sse", "@gemmlowp", ], @@ -418,6 +420,15 @@ cc_library( }), ) +cc_library( + name = "test_util", + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + ":types", + ], +) + cc_test( name = "tensor_utils_test", srcs = ["tensor_utils_test.cc"], @@ 
-438,6 +449,84 @@ cc_test( ], ) +cc_test( + name = "depthwiseconv_float_test", + srcs = ["depthwiseconv_float_test.cc"], + deps = [ + ":optimized_base", + ":reference_base", + ":test_util", + ":types", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "depthwiseconv_quantized_test", + srcs = ["depthwiseconv_quantized_test.cc"], + deps = [ + ":optimized_base", + ":reference_base", + ":test_util", + ":types", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "resize_bilinear_test", + srcs = ["resize_bilinear_test.cc"], + tags = ["tflite_not_portable"], + deps = [ + ":optimized_base", + ":reference_base", + ":test_util", + ":types", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "softmax_quantized_test", + timeout = "long", + srcs = [ + "softmax_quantized_test.cc", + ], + deps = [ + ":optimized_base", + ":quantization_util", + ":reference_base", + ":test_util", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "logsoftmax_quantized_test", + timeout = "long", + srcs = [ + "logsoftmax_quantized_test.cc", + ], + tags = ["tflite_not_portable"], + deps = [ + ":optimized_base", + ":quantization_util", + ":reference_base", + ":test_util", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "log_quantized_test", + srcs = ["log_quantized_test.cc"], + deps = [ + ":optimized_base", + ":reference_base", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "cpu_check", hdrs = [ diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h index ede95dfee069fa078b89d23b68ce1bb264761351..b86ca49c116875672c4516a2a47f7dae511a7116 100644 --- a/tensorflow/contrib/lite/kernels/internal/common.h +++ b/tensorflow/contrib/lite/kernels/internal/common.h @@ -87,12 +87,12 @@ float ActivationFunction(float x) { output_activation_max); } -inline int32 MultiplyByQuantizedMultiplierSmallerThanOne( - int32 x, int32 
quantized_multiplier, int right_shift) { +inline int32 MultiplyByQuantizedMultiplierSmallerThanOneExp( + int32 x, int32 quantized_multiplier, int left_shift) { using gemmlowp::RoundingDivideByPOT; using gemmlowp::SaturatingRoundingDoublingHighMul; return RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift); + SaturatingRoundingDoublingHighMul(x, quantized_multiplier), -left_shift); } inline int32 MultiplyByQuantizedMultiplierGreaterThanOne( diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..844ee6a53dd65b81f21ae1ef5b6d04192744a304 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc @@ -0,0 +1,162 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include + +#include +#include "tensorflow/contrib/lite/kernels/internal/test_util.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h" + +namespace tflite { +namespace { + +// Runs the DepthwiseConv and compares against the reference implementation. +template <FusedActivationFunctionType Ac> +void TestOneDepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride, int pad_width, int pad_height, + int depth_multiplier, const Dims<4>& output_dims) { + const int output_buffer_size = RequiredBufferSizeForDims(output_dims); + std::vector<float> output_data(output_buffer_size); + std::vector<float> reference_output_data(output_buffer_size); + reference_ops::DepthwiseConv<Ac>(input_data, input_dims, filter_data, + filter_dims, bias_data, bias_dims, stride, + pad_width, pad_height, depth_multiplier, + reference_output_data.data(), output_dims); + optimized_ops::DepthwiseConv<Ac>(input_data, input_dims, filter_data, + filter_dims, bias_data, bias_dims, stride, + pad_width, pad_height, depth_multiplier, + output_data.data(), output_dims); + double sum_abs_diff = 0; + float max_abs_val = 0; + for (int i = 0; i < output_buffer_size; i++) { + sum_abs_diff += std::abs(output_data[i] - reference_output_data[i]); + max_abs_val = std::max(max_abs_val, std::abs(reference_output_data[i])); + } + if (sum_abs_diff != 0.f) { + const float mean_diff = + static_cast<float>(sum_abs_diff / output_buffer_size); + const float relative_error = std::abs(mean_diff) / max_abs_val; + ASSERT_LT(relative_error, 1e-5f); + } +} + +void TestOneDepthwiseConv(FusedActivationFunctionType Ac, + const float*
input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride, int pad_width, int pad_height, + int depth_multiplier, const Dims<4>& output_dims) { +#define TOCO_HANDLE_CASE(AC_TYPE) \ + if (AC_TYPE == Ac) { \ + TestOneDepthwiseConv<AC_TYPE>(input_data, input_dims, filter_data, \ + filter_dims, bias_data, bias_dims, stride, \ + pad_width, pad_height, depth_multiplier, \ + output_dims); \ + return; \ + } + TOCO_HANDLE_CASE(FusedActivationFunctionType::kNone) + TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu) + TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu1) + TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu6) +#undef TOCO_HANDLE_CASE +} + +// This function picks some random DepthwiseConv params, which may or may not +// be legal. If they're not legal, it returns false. If they're legal, +// it runs the DepthwiseConv test and returns true. This allows the caller +// to loop until a test has been run. +bool TryTestOneDepthwiseConv() { + // We have to pick a lot of positive values, where we are particularly + // interested in small values because they are most likely to be special + // cases in optimized implementations, and secondarily because they allow + // tests to run fast, which means we can run more tests and get more + // coverage.
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int filter_width = ExponentialRandomPositiveInt(0.9f, 4, 10); + const int filter_height = ExponentialRandomPositiveInt(0.9f, 4, 10); + const int depth_multiplier = ExponentialRandomPositiveInt(0.8f, 6, 50); + const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); + const int output_depth = input_depth * depth_multiplier; + // The optimized DepthwiseConv implementation currently uses a fixed-size + // accumulator buffer on the stack, with that size. This currently means + // that it does not support larger output depths. It CHECK's for it, + // so it's safe in the sense that if a larger output depth was encountered, + // it would explicitly fail. We just need to adjust our testing to that + // constraint. + const int kMaxSupportedOutputDepth = 1024; + if (output_depth > kMaxSupportedOutputDepth) { + return false; + } + const auto ac = RandomElement(std::vector<FusedActivationFunctionType>( + {FusedActivationFunctionType::kNone, FusedActivationFunctionType::kRelu, + FusedActivationFunctionType::kRelu6, + FusedActivationFunctionType::kRelu1})); + Dims<4> input_dims_inference = + MakeDimsForInference(input_depth, input_width, input_height, batch); + Dims<4> output_dims_inference; + int pad_width, pad_height; + const auto padding_type = + UniformRandomInt(0, 1) ?
PaddingType::kSame : PaddingType::kValid; + if (!ComputeConvSizes(input_dims_inference, output_depth, filter_width, + filter_height, stride, padding_type, + &output_dims_inference, &pad_width, &pad_height)) { + return false; + } + Dims<4> filter_dims_inference = + MakeDimsForInference(output_depth, filter_width, filter_height, 1); + Dims<4> bias_dims_inference = MakeDimsForInference(output_depth, 1, 1, 1); + const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference); + const int filter_buffer_size = + RequiredBufferSizeForDims(filter_dims_inference); + std::vector<float> input_data(input_buffer_size); + std::vector<float> filter_data(filter_buffer_size); + std::vector<float> bias_data(output_depth); + const float input_amplitude = 1.f; + const float filter_amplitude = 1.f; + const float bias_amplitude = + filter_width * filter_height * input_amplitude * filter_amplitude; + FillRandom(&input_data, -input_amplitude, input_amplitude); + FillRandom(&filter_data, -filter_amplitude, filter_amplitude); + FillRandom(&bias_data, -bias_amplitude, bias_amplitude); + TestOneDepthwiseConv(ac, input_data.data(), input_dims_inference, + filter_data.data(), filter_dims_inference, + bias_data.data(), bias_dims_inference, stride, pad_width, + pad_height, depth_multiplier, output_dims_inference); + return true; +} + +void TestOneDepthwiseConv() { + while (!TryTestOneDepthwiseConv()) { + } +} + +TEST(TestDepthwiseConv, TestDepthwiseConv) { + const int kTestsToRun = 100 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + TestOneDepthwiseConv(); + } +} +} // namespace +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2c0fc8433e18fb7f7f89c17380210d94b39ffc94 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc @@ -0,0 +1,330 @@ +/* Copyright 2018 The
TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "tensorflow/contrib/lite/kernels/internal/test_util.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h" + +namespace tflite { +namespace { + +// Runs the DepthwiseConv and compares against the reference implementation. 
+template <FusedActivationFunctionType Ac> +int TestOneDepthwiseConvWithGivenOutputShift( + const std::uint8_t* input_data, const Dims<4>& input_dims, + std::int32_t input_offset, const std::uint8_t* filter_data, + const Dims<4>& filter_dims, std::int32_t filter_offset, + const std::int32_t* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + std::int32_t output_offset, std::int32_t output_multiplier, + int output_shift, std::int32_t output_activation_min, + std::int32_t output_activation_max, const Dims<4>& output_dims) { + const int output_buffer_size = RequiredBufferSizeForDims(output_dims); + std::vector<std::uint8_t> output_data(output_buffer_size); + std::vector<std::uint8_t> reference_output_data(output_buffer_size); + reference_ops::DepthwiseConv<Ac>( + input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, pad_width, pad_height, + depth_multiplier, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, + reference_output_data.data(), output_dims); + optimized_ops::DepthwiseConv<Ac>( + input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, pad_width, pad_height, + depth_multiplier, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data.data(), + output_dims); + int saturated_min = 0; + int saturated_max = 0; + std::vector<int> diff(output_buffer_size); + std::int64_t sum_diff = 0; + std::int64_t sum_abs_diff = 0; + for (int i = 0; i < output_buffer_size; i++) { + diff[i] = static_cast<int>(output_data[i]) - + static_cast<int>(reference_output_data[i]); + sum_diff += diff[i]; + sum_abs_diff += std::abs(diff[i]); + saturated_min += output_data[i] == output_activation_min; + saturated_max += output_data[i] == output_activation_max; + } + // These stats help understand test failures.
+ std::sort(std::begin(diff), std::end(diff)); + const int min_diff = diff.front(); + const int max_diff = diff.back(); + const int median_diff = diff[diff.size() / 2]; + const float mean_diff = static_cast<float>(sum_diff) / output_buffer_size; + const float mean_abs_diff = + static_cast<float>(sum_abs_diff) / output_buffer_size; + // Normally we should require bit-for-bit exact results. Unfortunately a bug + // in the Intel arm_neon_sse.h translation header that we use for x86 tests + // causes 1-bit inaccuracy in + // the vqrdmulh_n_s32 intrinsic, which causes off-by-1 errors in quantized + // DepthwiseConv ops. So we have to live with a few off-by-one errors for now, + // yet still ensure that no more than a small minority of values are wrong. + EXPECT_TRUE(std::abs(mean_diff) < 1e-5f && mean_abs_diff < 1e-5f && + std::abs(median_diff) == 0 && std::abs(min_diff) <= 1 && + std::abs(max_diff) <= 1); + if (saturated_min > 2 * saturated_max) { + return -1; + } + if (saturated_max > 2 * saturated_min) { + return 1; + } + return 0; +} + +// The point of this function is that we can't practically know which +// output_shift value to pass to test DepthwiseConv. It's not easy to guess (we +// could do some +// statistics for large size, but they would be fragile at smaller sizes), and +// guessing wrong would mean that all the values get saturated so the test +// becomes +// vacuous. So we just bisect our way to reasonable output_shift values.
+template <FusedActivationFunctionType Ac> +void TestOneDepthwiseConvBisectOutputShift( + const std::uint8_t* input_data, const Dims<4>& input_dims, + std::int32_t input_offset, const std::uint8_t* filter_data, + const Dims<4>& filter_dims, std::int32_t filter_offset, + const std::int32_t* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + std::int32_t output_offset, std::int32_t output_multiplier, + int output_activation_bisect_start, int output_activation_bisect_end, + std::int32_t output_activation_min, std::int32_t output_activation_max, + const Dims<4>& output_dims) { + ASSERT_LT(output_activation_bisect_start, output_activation_bisect_end) + << "Bisection failed ?!?!"; + int output_shift_bisect_midpoint = + (output_activation_bisect_start + output_activation_bisect_end) / 2; + int bisect_result = TestOneDepthwiseConvWithGivenOutputShift<Ac>( + input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, pad_width, pad_height, + depth_multiplier, output_offset, output_multiplier, + output_shift_bisect_midpoint, output_activation_min, + output_activation_max, output_dims); + // At this point we know that the test succeeded (otherwise it would have + // aborted). + if (bisect_result == 0) { + // The result isn't particularly saturated on one or the other side. + // All good, we're done. + return; + } + if (output_activation_bisect_start == output_activation_bisect_end - 1) { + // There is still some saturation on one side, but the bisection is + // finished anyways. We're done; nothing more we can do about it. This + // happens + // in particular when using an activation with a narrow range. + return; + } + // Continue the bisection based on the present result. + int new_output_activation_bisect_start = bisect_result == 1 + ? output_shift_bisect_midpoint + : output_activation_bisect_start; + int new_output_activation_bisect_end = bisect_result == 1 + ?
output_activation_bisect_end + : output_shift_bisect_midpoint; + TestOneDepthwiseConvBisectOutputShift<Ac>( + input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, pad_width, pad_height, + depth_multiplier, output_offset, output_multiplier, + new_output_activation_bisect_start, new_output_activation_bisect_end, + output_activation_min, output_activation_max, output_dims); +} + +template <FusedActivationFunctionType Ac> +void TestOneDepthwiseConv( + const std::uint8_t* input_data, const Dims<4>& input_dims, + std::int32_t input_offset, const std::uint8_t* filter_data, + const Dims<4>& filter_dims, std::int32_t filter_offset, + const std::int32_t* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + std::int32_t output_offset, std::int32_t output_multiplier, + std::int32_t output_activation_min, std::int32_t output_activation_max, + const Dims<4>& output_dims) { + TestOneDepthwiseConvBisectOutputShift<Ac>( + input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, pad_width, pad_height, + depth_multiplier, output_offset, output_multiplier, 0, 32, + output_activation_min, output_activation_max, output_dims); +} + +void TestOneDepthwiseConv( + FusedActivationFunctionType Ac, const std::uint8_t* input_data, + const Dims<4>& input_dims, std::int32_t input_offset, + const std::uint8_t* filter_data, const Dims<4>& filter_dims, + std::int32_t filter_offset, const std::int32_t* bias_data, + const Dims<4>& bias_dims, int stride, int pad_width, int pad_height, + int depth_multiplier, std::int32_t output_offset, + std::int32_t output_multiplier, std::int32_t output_activation_min, + std::int32_t output_activation_max, const Dims<4>& output_dims) { +#define TOCO_HANDLE_CASE(AC_TYPE) \ + if (AC_TYPE == Ac) { \ + TestOneDepthwiseConv<AC_TYPE>( \ + input_data, input_dims, input_offset, filter_data, filter_dims, \ + filter_offset, bias_data, bias_dims, stride, pad_width,
pad_height, \ + depth_multiplier, output_offset, output_multiplier, \ + output_activation_min, output_activation_max, output_dims); \ + return; \ + } + TOCO_HANDLE_CASE(FusedActivationFunctionType::kNone) + TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu) + TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu1) + TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu6) +#undef TOCO_HANDLE_CASE +} + +bool TryTestDepthwiseConv(int batch, int input_depth, int input_width, + int input_height, int filter_width, int filter_height, + int depth_multiplier, int stride, + PaddingType padding_type) { + const int output_depth = input_depth * depth_multiplier; + // The optimized DepthwiseConv implementation currently uses a fixed-size + // accumulator buffer on the stack, with that size. This currently means + // that it does not support larger output depths. It CHECK's for it, + // so it's safe in the sense that if a larger output depth was encountered, + // it would explicitly fail. We just need to adjust our testing to that + // constraint. 
+ const int kMaxSupportedOutputDepth = 1024; + if (output_depth > kMaxSupportedOutputDepth) { + return false; + } + const auto ac = RandomElement(std::vector<FusedActivationFunctionType>( + {FusedActivationFunctionType::kNone, FusedActivationFunctionType::kRelu, + FusedActivationFunctionType::kRelu6, + FusedActivationFunctionType::kRelu1})); + int output_activation_min = 0; + int output_activation_max = 255; + if (ac != FusedActivationFunctionType::kNone && UniformRandomInt(0, 1)) { + output_activation_min = UniformRandomInt(0, 50); + output_activation_max = UniformRandomInt(200, 255); + } + const std::int32_t output_multiplier = + UniformRandomInt(1 << 29, std::numeric_limits<std::int32_t>::max()); + const std::int32_t input_offset = UniformRandomInt(-256, 0); + const std::int32_t filter_offset = UniformRandomInt(-256, 0); + const std::int32_t output_offset = UniformRandomInt(-256, 0); + Dims<4> input_dims_inference = + MakeDimsForInference(input_depth, input_width, input_height, batch); + Dims<4> output_dims_inference; + int pad_width, pad_height; + if (!ComputeConvSizes(input_dims_inference, output_depth, filter_width, + filter_height, stride, padding_type, + &output_dims_inference, &pad_width, &pad_height)) { + return false; + } + Dims<4> filter_dims_inference = + MakeDimsForInference(output_depth, filter_width, filter_height, 1); + Dims<4> bias_dims_inference = MakeDimsForInference(output_depth, 1, 1, 1); + const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference); + const int filter_buffer_size = + RequiredBufferSizeForDims(filter_dims_inference); + std::vector<std::uint8_t> input_data(input_buffer_size); + std::vector<std::uint8_t> filter_data(filter_buffer_size); + std::vector<std::int32_t> bias_data(output_depth); + FillRandom(&input_data); + FillRandom(&filter_data); + FillRandom(&bias_data, -10000, 10000); + TestOneDepthwiseConv(ac, input_data.data(), input_dims_inference, + input_offset, filter_data.data(), filter_dims_inference, + filter_offset, bias_data.data(), bias_dims_inference, + stride, pad_width, pad_height,
depth_multiplier, + output_offset, output_multiplier, output_activation_min, + output_activation_max, output_dims_inference); + return true; +} + +// This function picks some random DepthwiseConv params, which may or may not +// be legal. If they're not legal, it returns false. If they're legal, +// it runs the DepthwiseConv test and returns true. This allows the caller +// to loop until a test has been run. +bool TryTestOneDepthwiseConv() { + // We have to pick a lot of positive values, where we are particularly + // interested in small values because they are most likely to be special + // cases in optimized implementations, and secondarily because they allow + // tests to run fast, which means we can run more tests and get more + // coverage. + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int filter_width = ExponentialRandomPositiveInt(0.9f, 4, 10); + const int filter_height = ExponentialRandomPositiveInt(0.9f, 4, 10); + const int depth_multiplier = ExponentialRandomPositiveInt(0.8f, 6, 50); + const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); + const auto padding_type = + UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid; + + return TryTestDepthwiseConv(batch, input_depth, input_width, input_height, + filter_width, filter_height, depth_multiplier, + stride, padding_type); +} + +// Tests parameters for the 3x3 filter kernel. 
+bool TryTestOneDepthwiseConv3x3Filter() { + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = 8 * ExponentialRandomPositiveInt(0.9f, 10, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int filter_width = 3; + const int filter_height = 3; + const int depth_multiplier = 1; + const int stride = UniformRandomInt(1, 2); + // Although the kernel supports only kValid padding, we test that kSame + // is using the correct code path. + const auto padding_type = + UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid; + + return TryTestDepthwiseConv(batch, input_depth, input_width, input_height, + filter_width, filter_height, depth_multiplier, + stride, padding_type); +} + +void TestOneDepthwiseConv() { + while (!TryTestOneDepthwiseConv()) { + } +} + +void TestOneDepthwiseConv3x3Filter() { + while (!TryTestOneDepthwiseConv3x3Filter()) { + } +} + +TEST(TestDepthwiseConv, TestDepthwiseConv) { + const int kTestsToRun = 10 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + TestOneDepthwiseConv(); + } +} + +TEST(TestDepthwiseConv3x3Filter, TestDepthwiseConv) { + const int kTestsToRun = 3 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + TestOneDepthwiseConv3x3Filter(); + } +} + +} // namespace +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc index f142374269606bdd3d4184af013749102666ab89..36c25388e8bde721d7644dc83d5b7c490d37b4d3 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" + +#include + #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" namespace tflite { @@ -40,6 +44,70 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr, hidden_state_ptr_batch); } +void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, + float input_weights_scale, + const int8_t* recurrent_weights_ptr, + float recurrent_weights_scale, const float* bias_ptr, + int input_size, int num_units, int batch_size, + TfLiteFusedActivation activation, + int8_t* quantized_input_ptr_batch, + int8_t* quantized_hidden_state_ptr_batch, + float* scaling_factors, float* hidden_state_ptr_batch, + float* output_ptr_batch) { + // Output = bias + tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size, + output_ptr_batch); + + // Save quantization and matmul computation for all zero input. + if (!tensor_utils::IsZeroVector(input_ptr_batch, batch_size * input_size)) { + // Quantize input from float to uint8 + quantization params (scaling + // factor). + float unused_min, unused_max; + // TODO(mirkov,raziel): replace this for-loop with a MACRO (or function) + // whichever is faster. + for (int b = 0; b < batch_size; ++b) { + const int offset = b * input_size; + tensor_utils::SymmetricQuantizeFloats( + input_ptr_batch + offset, input_size, + quantized_input_ptr_batch + offset, &unused_min, &unused_max, + &scaling_factors[b]); + scaling_factors[b] *= input_weights_scale; + } + + // Output += input * input_weights + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_weights_ptr, num_units, input_size, quantized_input_ptr_batch, + scaling_factors, batch_size, output_ptr_batch, /*result_stride=*/1); + } + + // Save quantization and matmul computation for all zero input. 
+ if (!tensor_utils::IsZeroVector(hidden_state_ptr_batch, + batch_size * num_units)) { + // Quantize hidden_state + float unused_min, unused_max; + for (int b = 0; b < batch_size; ++b) { + const int offset = b * num_units; + tensor_utils::SymmetricQuantizeFloats( + hidden_state_ptr_batch + offset, num_units, + quantized_hidden_state_ptr_batch + offset, &unused_min, &unused_max, + &scaling_factors[b]); + scaling_factors[b] *= recurrent_weights_scale; + } + + // Output += recurrent_weights * hidden_state + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_weights_ptr, num_units, num_units, + quantized_hidden_state_ptr_batch, scaling_factors, batch_size, + output_ptr_batch, /*result_stride=*/1); + } + + // Output = activation(Output) and update hidden_state + tensor_utils::ApplyActivationToVector( + output_ptr_batch, num_units * batch_size, activation, output_ptr_batch); + tensor_utils::VectorBatchVectorAssign(output_ptr_batch, num_units, batch_size, + hidden_state_ptr_batch); +} + void LstmStep( const float* input_ptr_batch, const float* input_to_input_weights_ptr, const float* input_to_forget_weights_ptr, @@ -81,6 +149,7 @@ void LstmStep( input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, input_gate_scratch, /*result_stride=*/1); } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1); @@ -95,8 +164,7 @@ void LstmStep( if (!use_cifg) { tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, input_gate_scratch, - /*result_stride=*/1); + n_batch, input_gate_scratch, /*result_stride=*/1); } tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr, @@ -187,5 +255,263 @@ void LstmStep( output_state_ptr); } +// TODO(alanchiao): move this to tensor_utils. 
+void VectorMultiply(const int8_t* vector, const int v_size, const float scale, + float* result) { + for (int i = 0; i < v_size; ++i) { + *result++ = scale * *vector++; + } +} + +void LstmStep( + const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, + float input_to_input_weights_scale, + const int8_t* input_to_forget_weights_ptr, + float input_to_forget_weights_scale, + const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, + const int8_t* input_to_output_weights_ptr, + float input_to_output_weights_scale, + const int8_t* recurrent_to_input_weights_ptr, + float recurrent_to_input_weights_scale, + const int8_t* recurrent_to_forget_weights_ptr, + float recurrent_to_forget_weights_scale, + const int8_t* recurrent_to_cell_weights_ptr, + float recurrent_to_cell_weights_scale, + const int8_t* recurrent_to_output_weights_ptr, + float recurrent_to_output_weights_scale, + const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, + const int8_t* cell_to_forget_weights_ptr, + float cell_to_forget_weights_scale, + const int8_t* cell_to_output_weights_ptr, + float cell_to_output_weights_scale, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, + float projection_weights_scale, const float* projection_bias_ptr, + const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, + int n_output, float* input_gate_scratch, float* forget_gate_scratch, + float* cell_scratch, float* output_gate_scratch, float* scaling_factors, + float* product_scaling_factors, float* recovered_cell_weights, + int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr, + int8_t* quantized_cell_state_ptr, float* output_state_ptr, + float* cell_state_ptr, float* output_ptr_batch) { + // Since we have already checked that weights are all there or none, we can + // check the existence of only one to get the
condition. + const bool use_cifg = (input_to_input_weights_ptr == nullptr); + const bool use_peephole = (cell_to_output_weights_ptr != nullptr); + // Initialize scratch buffers with bias. + if (!use_cifg) { + tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch, + input_gate_scratch); + } + tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch, + forget_gate_scratch); + tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch, + output_gate_scratch); + + if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_input; + tensor_utils::SymmetricQuantizeFloats( + input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset, + &unused_min, &unused_max, &scaling_factors[b]); + } + // For each batch and cell: compute input_weight * input. 
+ if (!use_cifg) { + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_input_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_input_weights_ptr, n_cell, n_input, + quantized_input_ptr_batch, product_scaling_factors, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_forget_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, forget_gate_scratch, + /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_output_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, output_gate_scratch, + /*result_stride=*/1); + } + + if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_output; + tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output, + quantized_output_state_ptr + offset, + &unused_min, &unused_max, + &scaling_factors[b]); + } + // For each batch and cell: compute recurrent_weight * output_state. 
+ if (!use_cifg) { + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_input_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_input_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_forget_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_forget_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + forget_gate_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_cell_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + cell_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_output_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_output_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + output_gate_scratch, /*result_stride=*/1); + } + + // Save quantization and matmul computation for all zero input. + bool is_cell_state_all_zeros = + tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); + + // For each batch and cell: update input gate. + if (!use_cifg) { + if (use_peephole && !is_cell_state_all_zeros) { + VectorMultiply(cell_to_input_weights_ptr, n_cell, + 1. 
/ cell_to_input_weights_scale, recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + input_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, + input_gate_scratch); + } + + // For each batch and cell: update forget gate. + if (use_peephole && !is_cell_state_all_zeros) { + VectorMultiply(cell_to_forget_weights_ptr, n_cell, + 1. / cell_to_forget_weights_scale, recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + forget_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, + forget_gate_scratch); + + // For each batch and cell: update the cell. + tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, + n_batch * n_cell, cell_state_ptr); + tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, + params->activation, cell_scratch); + if (use_cifg) { + tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, + forget_gate_scratch); + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr); + } else { + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); + } + if (params->cell_clip > 0.0) { + tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, + params->cell_clip, cell_state_ptr); + } + + is_cell_state_all_zeros = + tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); + // For each batch and cell: update the output gate. + if (use_peephole && !is_cell_state_all_zeros) { + VectorMultiply(cell_to_output_weights_ptr, n_cell, + 1. 
/ cell_to_output_weights_scale, recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + output_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, + output_gate_scratch); + tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, + params->activation, cell_scratch); + tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, + n_batch * n_cell, output_gate_scratch); + + // For each batch: update the projection and output_state. + const bool use_projection_weight = (projection_weights_ptr != nullptr); + const bool use_projection_bias = (projection_bias_ptr != nullptr); + if (use_projection_weight) { + if (use_projection_bias) { + tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, + n_batch, output_ptr_batch); + } else { + tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output); + } + if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) { + // Save quantization and matmul computation for all zero input. 
+ float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_cell; + tensor_utils::SymmetricQuantizeFloats( + output_gate_scratch + offset, n_cell, + quantized_cell_state_ptr + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * projection_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr, + product_scaling_factors, n_batch, output_ptr_batch, + /*result_stride=*/1); + } + if (params->proj_clip > 0.0) { + tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, + params->proj_clip, output_ptr_batch); + } + } else { + tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, + output_ptr_batch); + } + tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output, + output_state_ptr); +} + } // namespace kernel_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h index 3ec60ee57a87833959a34ba95d32df15bea188a4..2a11b37a6069367e8232350c2fc68d4c385e14ba 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h @@ -35,6 +35,27 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr, TfLiteFusedActivation activation, float* hidden_state_ptr_batch, float* output_ptr_batch); +// Performs a quantized RNN batch inference step. Same as above, but for +// quantization purposes, we also pass in quantized_hidden_state_ptr_batch and +// quantized_input_ptr_batch pointers for temporary storage of the quantized +// values of hidden_state_ptr_batch and input_ptr_batch, respectively. +// These temporary storages are expected to be preallocated to the same size as +// the respective pointers. 
+// An additional preallocated temporary storage 'scaling_factors' (of size +// batch_size) is used to store the scaling factors of the quantization (used +// for recovery). +// {input,recurrent}_weights_scale params are used for dequantization/recovery. +void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, + float input_weights_scale, + const int8_t* recurrent_weights_ptr, + float recurrent_weights_scale, const float* bias_ptr, + int input_size, int num_units, int batch_size, + TfLiteFusedActivation activation, + int8_t* quantized_input_ptr_batch, + int8_t* quantized_hidden_state_ptr_batch, + float* scaling_factors, float* hidden_state_ptr_batch, + float* output_ptr_batch); + // Performs an LSTM batch inference step for input specified by input_ptr_batch. // The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and // biases (*_bias_ptr), and buffers (*_scratch), along with additional @@ -71,6 +92,89 @@ void LstmStep( float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, float* output_ptr_batch); +// Same as above but with quantized weight matrices. In detail: +// Input of size 'n_batch * n_input': +// input_ptr_batch +// +// LSTM weights: +// Quantized input weights of size 'n_cell * n_input': +// input_to_input_weights - optional (can be nullptr) +// input_to_forget_weights +// input_to_cell_weights +// input_to_output_weights +// Quantized recurrent weights of size 'n_cell * n_output': +// recurrent_to_input_weights - optional +// recurrent_to_forget_weights +// recurrent_to_cell_weights +// recurrent_to_output_weights +// Quantized peephole weights of size 'n_cell', representing diagonal matrices. +// cell_to_input_weights - optional +// cell_to_forget_weights - optional +// cell_to_output_weights - optional +// Quantized projection weights of size 'n_output * n_cell' +// projection_weights_ptr - optional +// Weight scales (scalars) for each of the weights above.
+// input_to_input_weights_scale - optional +// input_to_forget_weights_scale +// input_to_cell_weights_scale +// input_to_output_weights_scale +// recurrent_to_input_weights_scale - optional +// recurrent_to_forget_weights_scale +// recurrent_to_cell_weights_scale +// recurrent_to_output_weights_scale +// cell_to_input_weights_scale - optional +// cell_to_forget_weights_scale - optional +// cell_to_output_weights_scale - optional +// projection_weights_scale - optional +// Gate biases of size 'n_cell': +// input_gate_bias_ptr - optional +// forget_gate_bias_ptr +// cell_bias_ptr +// output_gate_bias_ptr +// +// Temporary pre-allocated storage for quantized values: +// quantized_input_ptr_batch (same size as input_ptr_batch) +// quantized_output_state_ptr (same size as output_state_ptr) +// quantized_cell_state_ptr (same size as cell_state_ptr) +// Temporary pre-allocated storage for recovered values: +// recovered_cell_weights (same size as cell_to_*_weights) +// +// Outputs: +// output_state_ptr - size 'n_batch * n_output' +// cell_state_ptr - size 'n_batch * n_cell' +// output_ptr_batch - size 'n_batch * n_output' +void LstmStep( + const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, + float input_to_input_weights_scale, + const int8_t* input_to_forget_weights_ptr, + float input_to_forget_weights_scale, + const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, + const int8_t* input_to_output_weights_ptr, + float input_to_output_weights_scale, + const int8_t* recurrent_to_input_weights_ptr, + float recurrent_to_input_weights_scale, + const int8_t* recurrent_to_forget_weights_ptr, + float recurrent_to_forget_weights_scale, + const int8_t* recurrent_to_cell_weights_ptr, + float recurrent_to_cell_weights_scale, + const int8_t* recurrent_to_output_weights_ptr, + float recurrent_to_output_weights_scale, + const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, + const int8_t* cell_to_forget_weights_ptr, + float
cell_to_forget_weights_scale, + const int8_t* cell_to_output_weights_ptr, + float cell_to_output_weights_scale, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, + float projection_weights_scale, const float* projection_bias_ptr, + const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, + int n_output, float* input_gate_scratch, float* forget_gate_scratch, + float* cell_scratch, float* output_gate_scratch, float* scaling_factors, + float* product_scaling_factors, float* recovered_cell_weights, + int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr, + int8_t* quantized_cell_state_ptr, float* output_state_ptr, + float* cell_state_ptr, float* output_ptr_batch); + } // namespace kernel_utils } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7e9ff5242a43a8b54e0e6ae167cdcf7a341c918e --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc @@ -0,0 +1,333 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS + +#include +#include +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" + +namespace { + +class NumberGenerator { + public: + std::vector RandomIntVector(int n, int min_val, int max_val) { + std::vector vec(n); + double scale = static_cast(max_val + 1 - min_val) / engine_.max(); + for (auto& it : vec) { + it = min_val + std::floor(engine_() * scale); + } + return vec; + } + + std::mt19937 engine_; +}; + +class LogQuantizedTest : public ::testing::Test { + public: + NumberGenerator generator_; +}; + +// input_integer_bits <= 30. output_integer_bits > 0. +inline int32 LogPositiveValuesViaFloat(int32 input_val, int input_integer_bits, + int output_integer_bits) { + const double float_log_sum_of_exps = std::log( + static_cast(input_val) * 0.5 / (1 << (30 - input_integer_bits))); + static constexpr double min_int = + static_cast(std::numeric_limits::min()); + static constexpr double max_int = + static_cast(std::numeric_limits::max()); + double double_result = tflite::TfLiteRound(float_log_sum_of_exps * + (1 << (31 - output_integer_bits))); + return static_cast( + std::min(max_int, std::max(min_int, double_result))); +} + +void CheckOutputData(const std::vector& test_output, + const std::vector& reference_output, + const std::vector& test_input, + const string& check_label, int input_integer_bits, + int output_integer_bits, int tolerance) { + // In the special case of small input, specifically raw value of 5, a rounding + // up leads to difference in the output. We do not aim to be accurate for + // very small input values, and there should be sufficient input fractional + // bits that this is a small input. 
+ static constexpr double error_from_rounding_up = 0.0224585; + const int n = test_output.size(); + ASSERT_EQ(n, reference_output.size()); + for (int i = 0; i < n; ++i) { + // Adjust tolerance when input <= 5*2^-(31-input_integer_bits). + const int adjusted_tolerance = + test_input[i] > 5 + ? tolerance + : std::max(tolerance, static_cast(std::ceil( + error_from_rounding_up * + (1 << (31 - output_integer_bits))))); + ASSERT_LE(std::abs(test_output[i] - reference_output[i]), + adjusted_tolerance) + << "Failure in \"" << check_label << "\" at i=" << i + << ", test_input[i]=" << test_input[i] << "=" + << static_cast(test_input[i]) / (1 << (31 - input_integer_bits)) + << ", test_output[i]=" << test_output[i] << "=" + << static_cast(test_output[i]) / + (1 << (31 - output_integer_bits)) + << ", reference_output[i]=" << reference_output[i] << "=" + << static_cast(reference_output[i]) / + (1 << (31 - output_integer_bits)) + << ", difference[i]=" << std::abs(reference_output[i] - test_output[i]) + << "=" + << static_cast(std::abs(reference_output[i] - test_output[i])) / + (1 << (31 - output_integer_bits)) + << "; tolerance=" << tolerance + << ", adj tolerance=" << adjusted_tolerance; + } +} + +void RightShiftVector(const std::vector& shifts, + std::vector* vec) { + const int n = vec->size(); + ASSERT_EQ(n, shifts.size()); + for (int i = 0; i < n; ++i) { + vec->at(i) = std::max(1, vec->at(i) >> shifts[i]); + } +} + +template +void RunSingleTest(const std::vector& test_input, + const string& check_label, int tolerance) { + const int n = test_input.size(); + std::vector float_gen_output(n, 0); + std::vector reference_output(n, 0); + std::vector optimized_output(n, 0); + + // Workaround the stupid things that intelligent humans do. + // Consequence of __builtin_clz(0u) may equal 31 instead of 32. 
+ std::vector fudged_input(n, 0); + for (int i = 0; i < n; ++i) { + fudged_input[i] = std::max(test_input[i], 2); + } + + for (int i = 0; i < n; ++i) { + reference_output[i] = + tflite::reference_ops::log_x_for_x_greater_than_or_equal_to_1_impl< + OutputIntegerBits, InputIntegerBits>( + gemmlowp::FixedPoint::FromRaw( + fudged_input[i])) + .raw(); + optimized_output[i] = + tflite::optimized_ops::log_x_for_x_greater_than_or_equal_to_1_impl< + OutputIntegerBits, InputIntegerBits>( + gemmlowp::FixedPoint::FromRaw( + fudged_input[i])) + .raw(); + float_gen_output[i] = LogPositiveValuesViaFloat( + fudged_input[i], InputIntegerBits, OutputIntegerBits); + } + // Note that first check is intolerant. + { + std::ostringstream label; + label << check_label << " / optimized vs reference / InputIntegerBits=" + << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits; + CheckOutputData( + optimized_output, reference_output, test_input, label.str(), + InputIntegerBits, OutputIntegerBits, 0); + } + { + std::ostringstream label; + label << check_label << " / reference vs float-gen / InputIntegerBits=" + << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits; + CheckOutputData( + reference_output, float_gen_output, test_input, label.str(), + InputIntegerBits, OutputIntegerBits, tolerance); + } + { + std::ostringstream label; + label << check_label << " optimized vs float-gen / InputIntegerBits=" + << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits; + CheckOutputData( + optimized_output, float_gen_output, test_input, label.str(), + InputIntegerBits, OutputIntegerBits, tolerance); + } +} + +template +void RunSingleTest(const std::vector& test_input, int input_integer_bits, + const string& check_label, int tolerance) { +#define INPUT_CASE(K) \ + case K: \ + return RunSingleTest(test_input, check_label, \ + tolerance) + switch (input_integer_bits) { + INPUT_CASE(0); + INPUT_CASE(1); + INPUT_CASE(2); + INPUT_CASE(3); + INPUT_CASE(4); + INPUT_CASE(5); 
+ INPUT_CASE(6); + INPUT_CASE(7); + INPUT_CASE(8); + INPUT_CASE(9); + INPUT_CASE(10); + INPUT_CASE(11); + INPUT_CASE(12); + INPUT_CASE(13); + INPUT_CASE(14); + INPUT_CASE(15); + INPUT_CASE(16); + INPUT_CASE(17); + INPUT_CASE(18); + INPUT_CASE(19); + INPUT_CASE(20); + INPUT_CASE(21); + INPUT_CASE(22); + INPUT_CASE(23); + INPUT_CASE(24); + INPUT_CASE(25); + INPUT_CASE(26); + INPUT_CASE(27); + INPUT_CASE(28); + INPUT_CASE(29); + default: + ASSERT_LE(input_integer_bits, 30) + << "Input integer bits not handled: " << input_integer_bits; + } +#undef INPUT_CASE +} + +void RunSingleTest(const std::vector& test_input, int input_integer_bits, + int output_integer_bits, const string& check_label, + int tolerance) { +#define OUTPUT_CASE(K) \ + case K: \ + return RunSingleTest(test_input, input_integer_bits, check_label, \ + tolerance) + switch (output_integer_bits) { + OUTPUT_CASE(0); + OUTPUT_CASE(1); + OUTPUT_CASE(2); + OUTPUT_CASE(3); + OUTPUT_CASE(4); + OUTPUT_CASE(5); + OUTPUT_CASE(6); + OUTPUT_CASE(7); + OUTPUT_CASE(8); + OUTPUT_CASE(9); + OUTPUT_CASE(10); + OUTPUT_CASE(11); + OUTPUT_CASE(12); + OUTPUT_CASE(13); + OUTPUT_CASE(14); + OUTPUT_CASE(15); + OUTPUT_CASE(16); + OUTPUT_CASE(17); + OUTPUT_CASE(18); + OUTPUT_CASE(19); + OUTPUT_CASE(20); + OUTPUT_CASE(21); + OUTPUT_CASE(22); + OUTPUT_CASE(23); + OUTPUT_CASE(24); + OUTPUT_CASE(25); + OUTPUT_CASE(26); + OUTPUT_CASE(27); + OUTPUT_CASE(28); + OUTPUT_CASE(29); + default: + ASSERT_LE(output_integer_bits, 30) + << "Output integer bits not handled: " << output_integer_bits; + } +#undef OUTPUT_CASE +} + +void RunUniformTest(int test_size, int input_integer_bits, + int output_integer_bits, const string& check_label, + int tolerance, NumberGenerator* generator) { + std::vector test_data = generator->RandomIntVector( + test_size, 2, std::numeric_limits::max() - 1); + test_data[0] = 2; + test_data[1] = 3; + test_data[2] = 4; + test_data[3] = std::numeric_limits::max() - 2; + test_data[4] = std::numeric_limits::max() - 1; +
test_data[5] = std::numeric_limits::max(); + + RunSingleTest(test_data, input_integer_bits, output_integer_bits, + check_label + " / uniform test", tolerance); +} + +void RunUniformShiftUniformTest(int test_size, int input_integer_bits, + int output_integer_bits, + const string& check_label, int tolerance, + NumberGenerator* generator) { + std::vector test_data = generator->RandomIntVector( + test_size, 2, std::numeric_limits::max() - 1); + std::vector shifts = generator->RandomIntVector(test_size, 0, 29); + RightShiftVector(shifts, &test_data); + + RunSingleTest(test_data, input_integer_bits, output_integer_bits, + check_label + " / shifted test", tolerance); +} + +TEST_F(LogQuantizedTest, VariedIntegerBits) { + static constexpr int kVariations = 250; + static constexpr int kRunSize = 250; + static constexpr int kIntegerTolerance = 8; + static constexpr double kOutputFloatTolerance = 7.0e-7; + + std::vector input_integer_bits = + generator_.RandomIntVector(kVariations, 0, 24); + std::vector output_integer_bits = + generator_.RandomIntVector(kVariations, 1, 10); + + for (int i = 0; i < kVariations; ++i) { + int var_output_integer_bits = output_integer_bits[i]; + int tolerance = + std::max(1.0 * kIntegerTolerance, + (1 << (31 - var_output_integer_bits)) * kOutputFloatTolerance); + + RunUniformTest(kRunSize, input_integer_bits[i], var_output_integer_bits, + "VariedIntegerBits", tolerance, &generator_); + RunUniformShiftUniformTest(kRunSize, input_integer_bits[i], + var_output_integer_bits, "VariedIntegerBits", + tolerance, &generator_); + } +} + +TEST_F(LogQuantizedTest, SelectedIntegerBits) { + static constexpr int kInputBits = 12; + static constexpr int kOutputBits = 5; + static constexpr int kRunSize = 100000; + static constexpr int kIntegerTolerance = 4; + + RunUniformTest(kRunSize, kInputBits, kOutputBits, "SelectedIntegerBits", + kIntegerTolerance, &generator_); + RunUniformShiftUniformTest(kRunSize, kInputBits, kOutputBits, + "SelectedIntegerBits", 
kIntegerTolerance, + &generator_); +} + +} // namespace diff --git a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b7531ea2e202cd6fe012e0fa675380775016d38f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc @@ -0,0 +1,241 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/test_util.h" + +namespace tflite { +namespace { + +void RunLogSoftmaxFloatReference(const uint8* input_data, + const Dims<4>& dims_common, int32 input_offset, + const double input_scale, int stride, + float beta, uint8* reference_output_data) { + const int ref_buffer_size = RequiredBufferSizeForDims(dims_common); + std::vector reference_dequant_data(ref_buffer_size); + std::vector reference_output_float_data(ref_buffer_size); + + // Reference data generated via Dequant of input into float, and then applying + // float LogSoftmax. + reference_ops::Dequantize(input_data, dims_common, input_offset, input_scale, + reference_dequant_data.data(), dims_common); + optimized_ops::LogSoftmax(reference_dequant_data.data(), dims_common, + reference_output_float_data.data(), dims_common); + // Work with quantized scaling for LogSoftmax, under which 255 represents 0, + // and -16 gets nudged up to 0. + for (int i = 0; i < ref_buffer_size; i++) { + reference_output_data[i] = std::max( + 0, static_cast( + 255 + std::round(16.0f * reference_output_float_data[i]))); + } +} + +void CheckOutputData(const uint8* test_output, const uint8* reference_output, + const Dims<4>& dims_common, const string& check_label, + bool be_exacting) { + const int buffer_size = RequiredBufferSizeForDims(dims_common); + // While calculating some metrics in floating point, we work with quantized + // scaling. 
+ std::vector diff(buffer_size); + int64_t sum_diff = 0; + int64_t sum_abs_diff = 0; + for (int i = 0; i < buffer_size; i++) { + diff[i] = static_cast(test_output[i]) - reference_output[i]; + sum_diff += diff[i]; + sum_abs_diff += std::abs(diff[i]); + } + // These stats help understand test failures. + std::sort(std::begin(diff), std::end(diff)); + const int min_diff = diff.front(); + const int max_diff = diff.back(); + const int median_diff = diff[diff.size() / 2]; + const float mean_diff = static_cast(sum_diff) / buffer_size; + const float mean_abs_diff = static_cast(sum_abs_diff) / buffer_size; + // We either check for bit exactness (against the reference quantized version) + // or for general accuracy, allowing off-by-one (against the float reference). + if (be_exacting) { + ASSERT_TRUE(std::abs(min_diff) == 0 && std::abs(max_diff) == 0) + << check_label << ": " + << "std::abs(min_diff)=" << std::abs(min_diff) + << ", std::abs(max_diff)=" << std::abs(max_diff); + } else { + // For small numbers of samples, the estimates of the means vary more. + // Rather than widen the tolerances, we skip the smaller tests. + ASSERT_TRUE(((std::abs(mean_diff) < 2e-2f && mean_abs_diff < 3e-2f) || + buffer_size < 10000) && + std::abs(median_diff) == 0 && std::abs(min_diff) <= 1 && + std::abs(max_diff) <= 1) + << check_label << ": " + << "buffer_size=" << buffer_size << ", mean_diff=" << mean_diff + << ", mean_abs_diff=" << mean_abs_diff + << ", median_diff=" << median_diff << ", min_diff=" << min_diff + << ", max_diff=" << max_diff; + } +} + +// Runs the LogSoftmax and compares against the float reference implementation +// and the quantized reference implementation. 
+void RunOneLogSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common, + int32 input_offset, const double input_scale, + int stride, float beta) { + const int buffer_size = RequiredBufferSizeForDims(dims_common); + std::vector optimized_logsoftmax_output(buffer_size); + std::vector reference_float_logsoftmax_output(buffer_size); + std::vector reference_quant_logsoftmax_output(buffer_size); + + RunLogSoftmaxFloatReference(input_data, dims_common, input_offset, + input_scale, stride, beta, + reference_float_logsoftmax_output.data()); + + int32 input_beta_multiplier; + int input_beta_left_shift; + int32 reverse_scaling_divisor; + int reverse_scaling_right_shift; + static const int kScaledDiffIntegerBits = 5; + tflite::PreprocessLogSoftmaxScaling( + beta, input_scale, kScaledDiffIntegerBits, &input_beta_multiplier, + &input_beta_left_shift, &reverse_scaling_divisor, + &reverse_scaling_right_shift); + // diff_min has a negative value, and is used to limit the maximum magnitude + // of the diffs, which are <= 0. 
+ const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits, + input_beta_left_shift); + + optimized_ops::LogSoftmax(input_data, dims_common, input_beta_multiplier, + input_beta_left_shift, reverse_scaling_divisor, + reverse_scaling_right_shift, diff_min, + optimized_logsoftmax_output.data(), dims_common); + reference_ops::LogSoftmax( + input_data, dims_common, input_beta_multiplier, input_beta_left_shift, + reverse_scaling_divisor, reverse_scaling_right_shift, diff_min, + reference_quant_logsoftmax_output.data(), dims_common); + + CheckOutputData(optimized_logsoftmax_output.data(), + reference_float_logsoftmax_output.data(), dims_common, + "Optimized vs float reference", false); + CheckOutputData(optimized_logsoftmax_output.data(), + reference_quant_logsoftmax_output.data(), dims_common, + "Optimized vs quant reference", true); + CheckOutputData(reference_quant_logsoftmax_output.data(), + reference_float_logsoftmax_output.data(), dims_common, + "Quant reference vs float reference", false); +} + +// This function picks some random LogSoftmax params, which are checked for +// desirability. If not acceptable, it returns false. If they're OK, +// it runs the LogSoftmax test and returns true. This allows the caller +// to loop until a test has been run. +// +// Currently we do not reject for any reason. +bool TryOneUniformLogSoftmax() { + // We pick mostly positive values, on the whole emphasizing smaller values and + // therefore faster tests. We test a wider range of depths. In the case of + // LogSoftmax, the width and height really just create test repetitions. 
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = ExponentialRandomPositiveInt(0.75f, 175, 500); + const int input_width = ExponentialRandomPositiveInt(0.8f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.8f, 20, 200); + const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); + const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0)); + const int32 input_offset = UniformRandomInt(-256, 0); + static constexpr float beta = 1.0f; + + Dims<4> dims_common = + MakeDimsForInference(input_depth, input_width, input_height, batch); + const int buffer_size = RequiredBufferSizeForDims(dims_common); + + std::vector input_data(buffer_size); + FillRandom(&input_data); + RunOneLogSoftmaxTest(input_data.data(), dims_common, input_offset, + input_scale, stride, beta); + return true; +} + +// See TryOneUniformLogSoftmax() for a general description. +// +// Tests with "skyscraper" input patterns are included for two reasons. (a) +// Bimodal distributions are potentially challenging and perhaps more +// realistic than simple uniform random inputs. (b) Some implementations of +// LogSoftmax may adapt as they traverse the depth, and so we test handling of +// cases where relatively small values are encountered at the beginning and end. +bool TryOneSkyscraperLogSoftmax(bool small_depth) { + // We pick mostly positive values, on the whole emphasizing smaller values and + // therefore faster tests. We test a wider range of depths. In the case of + // LogSoftmax, the width and height really just create test repetitions. + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = small_depth + ? 
ExponentialRandomPositiveInt(0.75f, 40, 500) + : ExponentialRandomPositiveInt(0.75f, 175, 500); + const int input_width = ExponentialRandomPositiveInt(0.7f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.7f, 20, 200); + const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); + const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0)); + const int32 input_offset = UniformRandomInt(-256, 0); + static constexpr float beta = 1.0f; + // Extra parameters for skyscraper input patterns. + const double middle_proportion = + ExponentialRandomPositiveFloat(0.65f, 0.1, 1.0); + const int middle_min = UniformRandomInt(0, 255); + const int sides_max = UniformRandomInt(0, middle_min); + + Dims<4> dims_common = + MakeDimsForInference(input_depth, input_width, input_height, batch); + const int buffer_size = RequiredBufferSizeForDims(dims_common); + + std::vector input_data(buffer_size); + FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min, + sides_max); + RunOneLogSoftmaxTest(input_data.data(), dims_common, input_offset, + input_scale, stride, beta); + return true; +} + +TEST(TestQuantizedLogSoftmax, UniformLogSoftmaxTests) { + const int kTestsToRun = 1000; + for (int i = 0; i < kTestsToRun; i++) { + while (!TryOneUniformLogSoftmax()) { + } + } +} + +TEST(TestQuantizedLogSoftmax, SkyscraperLogSoftmaxTests) { + const int kTestsToRun = 1000; + for (int i = 0; i < kTestsToRun; i++) { + while (!TryOneSkyscraperLogSoftmax(false)) { + } + } +} + +TEST(TestQuantizedLogSoftmax, SmallSkyscraperLogSoftmaxTests) { + const int kTestsToRun = 1000; + for (int i = 0; i < kTestsToRun; i++) { + while (!TryOneSkyscraperLogSoftmax(true)) { + } + } +} +} // namespace +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h index dd6932ffe7b7a6f1101f146ce6472b0df4cbda3b..3fd00c89308d3b163111fda004287c715259352f 
100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -1691,14 +1691,20 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, const int filter_width = ArraySize(filter_dims, 1); const int output_height = ArraySize(output_dims, 2); const int output_width = ArraySize(output_dims, 1); +#ifdef USE_NEON + const bool shift_left = (output_shift <= 0); + const int32 multiplier_power_of_two = shift_left ? (1 << -output_shift) : 1; +#endif TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); -#ifdef __aarch64__ +// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on +// Jetson TX-2. This compiler does not support the offsetof() macro. +#if defined(__aarch64__) && !defined(GOOGLE_L4T) // Call kernel optimized for depthwise convolutions using 3x3 filters if // parameters are supported. - if (Fast3x3FilterKernelSupported(input_dims, filter_dims, stride_width, - stride_height, pad_width, pad_height, - depth_multiplier, output_dims)) { + if (Fast3x3FilterKernelSupported( + input_dims, filter_dims, stride_width, stride_height, pad_width, + pad_height, depth_multiplier, output_dims, output_shift)) { DepthwiseConv3x3Filter(input_data, input_dims, input_offset, filter_data, filter_dims, filter_offset, bias_data, bias_dims, stride_width, stride_height, pad_width, pad_height, @@ -1833,12 +1839,20 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, acc[j] = vld1q_s32(acc_buffer + i + 4 * j); } - // Fixed-point multiplication. - for (int j = 0; j < 4; j++) { - acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier); - } - for (int j = 0; j < 4; j++) { - acc[j] = RoundingDivideByPOT(acc[j], output_shift); + if (!shift_left) { + // Fixed-point multiplication. 
+ for (int j = 0; j < 4; j++) { + acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier); + } + for (int j = 0; j < 4; j++) { + acc[j] = RoundingDivideByPOT(acc[j], output_shift); + } + } else { + // Fixed-point multiplication. + for (int j = 0; j < 4; j++) { + acc[j] = vmulq_n_s32(acc[j], multiplier_power_of_two); + acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier); + } } // Add the output offset. for (int j = 0; j < 4; j++) { @@ -1870,12 +1884,21 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, for (; i <= num_output_values - 8; i += 8) { int32x4_t acc0 = vld1q_s32(acc_buffer + i); int32x4_t acc1 = vld1q_s32(acc_buffer + i + 4); - // Fixed-point multiplication. - acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); - acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); - // Rounding right shift. - acc0 = RoundingDivideByPOT(acc0, output_shift); - acc1 = RoundingDivideByPOT(acc1, output_shift); + if (!shift_left) { + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + // Rounding right shift. + acc0 = RoundingDivideByPOT(acc0, output_shift); + acc1 = RoundingDivideByPOT(acc1, output_shift); + } else { + // Fixed-point multiplication. + acc0 = vmulq_n_s32(acc0, multiplier_power_of_two); + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + + acc1 = vmulq_n_s32(acc1, multiplier_power_of_two); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + } // Add the output offset. acc0 = vaddq_s32(acc0, output_offset_vec); acc1 = vaddq_s32(acc1, output_offset_vec); @@ -1899,10 +1922,16 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, // that will have to go through the very slow scalar code. for (; i <= num_output_values - 4; i += 4) { int32x4_t acc = vld1q_s32(acc_buffer + i); - // Fixed-point multiplication. - acc = vqrdmulhq_n_s32(acc, output_multiplier); - // Rounding right shift. 
- acc = RoundingDivideByPOT(acc, output_shift); + if (!shift_left) { + // Fixed-point multiplication. + acc = vqrdmulhq_n_s32(acc, output_multiplier); + // Rounding right shift. + acc = RoundingDivideByPOT(acc, output_shift); + } else { + // Fixed-point multiplication. + acc = vmulq_n_s32(acc, multiplier_power_of_two); + acc = vqrdmulhq_n_s32(acc, output_multiplier); + } // Add the output offset. acc = vaddq_s32(acc, output_offset_vec); // Apply the activation function. @@ -1923,8 +1952,8 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, // Handle leftover values, one by one. This is very slow. for (; i < num_output_values; i++) { int32 acc = acc_buffer[i]; - acc = MultiplyByQuantizedMultiplierSmallerThanOne( - acc, output_multiplier, output_shift); + acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, + -output_shift); acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h index 55e0d5c3aa9ebb8b46403550e190b00a54cb53e5..a7b0d805a3acd35b592a35ba4266dfff4eb992cd 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h @@ -23,3848 +23,2912 @@ limitations under the License. namespace tflite { namespace optimized_ops { -#ifdef __aarch64__ - -inline void preload_l1_keep(const uint8* ptr) { -#ifdef GEMMLOWP_ARM_64 - asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) :); -#else - gemmlowp::Prefetch(ptr); -#endif -} - -// Implementation of quantized DepthwiseConv for 3x3 filters. - -// Below are helper structs to remove the use of arrays. -// There is an llvm bug that causes significant slowdown when using arrays for -// NEON intrinsics vector data types. 
-// See: https://bugs.llvm.org/show_bug.cgi?id=34945 - -struct Int32x8 { - int32x4_t low, high; -}; - -struct Filter3x3x8 { - int16x8_t f0, f1, f2, f3, f4, f5, f6, f7, f8; -}; - -// Loads 3x3 filter of depth 8 and adds filter offsets. -inline Filter3x3x8 Load3x3Filter(const uint8* filter_ptr, int32 filter_offset, - int output_depth) { - Filter3x3x8 filter; - - uint8x8_t temp_u8_0, temp_u8_1, temp_u8_2, temp_u8_3, temp_u8_4, temp_u8_5, - temp_u8_6, temp_u8_7, temp_u8_8; - int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset); - - temp_u8_0 = vld1_u8(filter_ptr + 0 * output_depth); - temp_u8_1 = vld1_u8(filter_ptr + 1 * output_depth); - temp_u8_2 = vld1_u8(filter_ptr + 2 * output_depth); - temp_u8_3 = vld1_u8(filter_ptr + 3 * output_depth); - temp_u8_4 = vld1_u8(filter_ptr + 4 * output_depth); - temp_u8_5 = vld1_u8(filter_ptr + 5 * output_depth); - temp_u8_6 = vld1_u8(filter_ptr + 6 * output_depth); - temp_u8_7 = vld1_u8(filter_ptr + 7 * output_depth); - temp_u8_8 = vld1_u8(filter_ptr + 8 * output_depth); - - filter.f0 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_0)); - filter.f1 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_1)); - filter.f2 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_2)); - filter.f3 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_3)); - filter.f4 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_4)); - filter.f5 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_5)); - filter.f6 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_6)); - filter.f7 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_7)); - filter.f8 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_8)); - - filter.f0 = vaddq_s16(filter.f0, filter_offset_vec); - filter.f1 = vaddq_s16(filter.f1, filter_offset_vec); - filter.f2 = vaddq_s16(filter.f2, filter_offset_vec); - filter.f3 = vaddq_s16(filter.f3, filter_offset_vec); - filter.f4 = vaddq_s16(filter.f4, filter_offset_vec); - filter.f5 = vaddq_s16(filter.f5, filter_offset_vec); - filter.f6 = vaddq_s16(filter.f6, filter_offset_vec); - filter.f7 = vaddq_s16(filter.f7, filter_offset_vec); - filter.f8 
= vaddq_s16(filter.f8, filter_offset_vec); - - return filter; -} - -// Applies activation, offset and downquantize on a set of accumulator -// registers that correspond to a 2x2 output of depth 8. -// Stores results to output. -inline void DownquantizeAndStore2x2Output( - Int32x8 acc_0, Int32x8 acc_1, Int32x8 acc_2, Int32x8 acc_3, - int32 output_offset, int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - using gemmlowp::RoundingDivideByPOT; - const int32x4_t output_offset_vec = vdupq_n_s32(output_offset); - const int32x4_t output_activation_min_vec = - vdupq_n_s32(output_activation_min); - const int32x4_t output_activation_max_vec = - vdupq_n_s32(output_activation_max); - - // Fixed-point multiplication. - acc_0.low = vqrdmulhq_n_s32(acc_0.low, output_multiplier); - acc_0.high = vqrdmulhq_n_s32(acc_0.high, output_multiplier); - acc_1.low = vqrdmulhq_n_s32(acc_1.low, output_multiplier); - acc_1.high = vqrdmulhq_n_s32(acc_1.high, output_multiplier); - acc_2.low = vqrdmulhq_n_s32(acc_2.low, output_multiplier); - acc_2.high = vqrdmulhq_n_s32(acc_2.high, output_multiplier); - acc_3.low = vqrdmulhq_n_s32(acc_3.low, output_multiplier); - acc_3.high = vqrdmulhq_n_s32(acc_3.high, output_multiplier); - - acc_0.low = RoundingDivideByPOT(acc_0.low, output_shift); - acc_0.high = RoundingDivideByPOT(acc_0.high, output_shift); - acc_1.low = RoundingDivideByPOT(acc_1.low, output_shift); - acc_1.high = RoundingDivideByPOT(acc_1.high, output_shift); - acc_2.low = RoundingDivideByPOT(acc_2.low, output_shift); - acc_2.high = RoundingDivideByPOT(acc_2.high, output_shift); - acc_3.low = RoundingDivideByPOT(acc_3.low, output_shift); - acc_3.high = RoundingDivideByPOT(acc_3.high, output_shift); - - // Add the output offset. 
- acc_0.low = vaddq_s32(acc_0.low, output_offset_vec); - acc_0.high = vaddq_s32(acc_0.high, output_offset_vec); - acc_1.low = vaddq_s32(acc_1.low, output_offset_vec); - acc_1.high = vaddq_s32(acc_1.high, output_offset_vec); - acc_2.low = vaddq_s32(acc_2.low, output_offset_vec); - acc_2.high = vaddq_s32(acc_2.high, output_offset_vec); - acc_3.low = vaddq_s32(acc_3.low, output_offset_vec); - acc_3.high = vaddq_s32(acc_3.high, output_offset_vec); - - // Apply the activation function. - acc_0.low = vmaxq_s32(acc_0.low, output_activation_min_vec); - acc_0.high = vmaxq_s32(acc_0.high, output_activation_min_vec); - acc_1.low = vmaxq_s32(acc_1.low, output_activation_min_vec); - acc_1.high = vmaxq_s32(acc_1.high, output_activation_min_vec); - acc_2.low = vmaxq_s32(acc_2.low, output_activation_min_vec); - acc_2.high = vmaxq_s32(acc_2.high, output_activation_min_vec); - acc_3.low = vmaxq_s32(acc_3.low, output_activation_min_vec); - acc_3.high = vmaxq_s32(acc_3.high, output_activation_min_vec); - - acc_0.low = vminq_s32(acc_0.low, output_activation_max_vec); - acc_0.high = vminq_s32(acc_0.high, output_activation_max_vec); - acc_1.low = vminq_s32(acc_1.low, output_activation_max_vec); - acc_1.high = vminq_s32(acc_1.high, output_activation_max_vec); - acc_2.low = vminq_s32(acc_2.low, output_activation_max_vec); - acc_2.high = vminq_s32(acc_2.high, output_activation_max_vec); - acc_3.low = vminq_s32(acc_3.low, output_activation_max_vec); - acc_3.high = vminq_s32(acc_3.high, output_activation_max_vec); - - // Saturating cast to uint8 and store to destination. 
- int16x4_t acc_0_low_s16 = vqmovn_s32(acc_0.low); - int16x4_t acc_0_high_s16 = vqmovn_s32(acc_0.high); - int16x4_t acc_1_low_s16 = vqmovn_s32(acc_1.low); - int16x4_t acc_1_high_s16 = vqmovn_s32(acc_1.high); - int16x4_t acc_2_low_s16 = vqmovn_s32(acc_2.low); - int16x4_t acc_2_high_s16 = vqmovn_s32(acc_2.high); - int16x4_t acc_3_low_s16 = vqmovn_s32(acc_3.low); - int16x4_t acc_3_high_s16 = vqmovn_s32(acc_3.high); - - int16x8_t res_0_s16 = vcombine_s16(acc_0_low_s16, acc_0_high_s16); - int16x8_t res_1_s16 = vcombine_s16(acc_1_low_s16, acc_1_high_s16); - int16x8_t res_2_s16 = vcombine_s16(acc_2_low_s16, acc_2_high_s16); - int16x8_t res_3_s16 = vcombine_s16(acc_3_low_s16, acc_3_high_s16); - - uint8x8_t res_0_u8 = vqmovun_s16(res_0_s16); - uint8x8_t res_1_u8 = vqmovun_s16(res_1_s16); - uint8x8_t res_2_u8 = vqmovun_s16(res_2_s16); - uint8x8_t res_3_u8 = vqmovun_s16(res_3_s16); - - vst1_u8(output_ptr, res_0_u8); - vst1_u8(output_ptr + output_depth, res_1_u8); - vst1_u8(output_ptr + output_depth * output_width, res_2_u8); - vst1_u8(output_ptr + output_depth * output_width + output_depth, res_3_u8); -} - -inline void DownquantizeAndStore(Int32x8 acc, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, - uint8* output_ptr) { - using gemmlowp::RoundingDivideByPOT; - const int32x4_t output_offset_vec = vdupq_n_s32(output_offset); - const int32x4_t output_activation_min_vec = - vdupq_n_s32(output_activation_min); - const int32x4_t output_activation_max_vec = - vdupq_n_s32(output_activation_max); - - acc.low = vqrdmulhq_n_s32(acc.low, output_multiplier); - acc.high = vqrdmulhq_n_s32(acc.high, output_multiplier); - - acc.low = RoundingDivideByPOT(acc.low, output_shift); - acc.high = RoundingDivideByPOT(acc.high, output_shift); - - acc.low = vaddq_s32(acc.low, output_offset_vec); - acc.high = vaddq_s32(acc.high, output_offset_vec); - - acc.low = vmaxq_s32(acc.low, output_activation_min_vec); - acc.high = 
vmaxq_s32(acc.high, output_activation_min_vec); - - acc.low = vminq_s32(acc.low, output_activation_max_vec); - acc.high = vminq_s32(acc.high, output_activation_max_vec); - - int16x4_t acc_low_s16 = vqmovn_s32(acc.low); - int16x4_t acc_high_s16 = vqmovn_s32(acc.high); - - int16x8_t res_s16 = vcombine_s16(acc_low_s16, acc_high_s16); - uint8x8_t res_u8 = vqmovun_s16(res_s16); - vst1_u8(output_ptr, res_u8); -} - -inline void DownquantizeAndStore2Output( - Int32x8 acc_0, Int32x8 acc_1, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - uint8* output_ptr, int output_ptr_offset) { - { - using gemmlowp::RoundingDivideByPOT; - const int32x4_t output_offset_vec = vdupq_n_s32(output_offset); - const int32x4_t output_activation_min_vec = - vdupq_n_s32(output_activation_min); - const int32x4_t output_activation_max_vec = - vdupq_n_s32(output_activation_max); - - // Fixed-point multiplication. - acc_0.low = vqrdmulhq_n_s32(acc_0.low, output_multiplier); - acc_0.high = vqrdmulhq_n_s32(acc_0.high, output_multiplier); - acc_1.low = vqrdmulhq_n_s32(acc_1.low, output_multiplier); - acc_1.high = vqrdmulhq_n_s32(acc_1.high, output_multiplier); - - acc_0.low = RoundingDivideByPOT(acc_0.low, output_shift); - acc_0.high = RoundingDivideByPOT(acc_0.high, output_shift); - acc_1.low = RoundingDivideByPOT(acc_1.low, output_shift); - acc_1.high = RoundingDivideByPOT(acc_1.high, output_shift); - - // Add the output offset. - acc_0.low = vaddq_s32(acc_0.low, output_offset_vec); - acc_0.high = vaddq_s32(acc_0.high, output_offset_vec); - acc_1.low = vaddq_s32(acc_1.low, output_offset_vec); - acc_1.high = vaddq_s32(acc_1.high, output_offset_vec); - - // Apply the activation function. 
- acc_0.low = vmaxq_s32(acc_0.low, output_activation_min_vec); - acc_0.high = vmaxq_s32(acc_0.high, output_activation_min_vec); - acc_1.low = vmaxq_s32(acc_1.low, output_activation_min_vec); - acc_1.high = vmaxq_s32(acc_1.high, output_activation_min_vec); - - acc_0.low = vminq_s32(acc_0.low, output_activation_max_vec); - acc_0.high = vminq_s32(acc_0.high, output_activation_max_vec); - acc_1.low = vminq_s32(acc_1.low, output_activation_max_vec); - acc_1.high = vminq_s32(acc_1.high, output_activation_max_vec); - } - - // Saturating cast to uint8 and store to destination. - int16x8_t res_0_s16; - { - int16x4_t acc_0_low_s16 = vqmovn_s32(acc_0.low); - int16x4_t acc_0_high_s16 = vqmovn_s32(acc_0.high); - res_0_s16 = vcombine_s16(acc_0_low_s16, acc_0_high_s16); - } - - int16x8_t res_1_s16; - { - int16x4_t acc_1_low_s16 = vqmovn_s32(acc_1.low); - int16x4_t acc_1_high_s16 = vqmovn_s32(acc_1.high); - res_1_s16 = vcombine_s16(acc_1_low_s16, acc_1_high_s16); - } - - uint8x8_t res_0_u8 = vqmovun_s16(res_0_s16); - uint8x8_t res_1_u8 = vqmovun_s16(res_1_s16); - vst1_u8(output_ptr, res_0_u8); - vst1_u8(output_ptr + output_ptr_offset, res_1_u8); -} - -// Performs multiply accumulate on 3 inputs of depth 8. -inline Int32x8 MultiplyAccumulateRow(Int32x8 accum, int16x8_t f0, int16x8_t f1, - int16x8_t f2, int16x8_t i0, int16x8_t i1, - int16x8_t i2) { - accum.low = vmlal_s16(accum.low, vget_low_s16(f0), vget_low_s16(i0)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f0), vget_high_s16(i0)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f1), vget_low_s16(i1)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f1), vget_high_s16(i1)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f2), vget_low_s16(i2)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f2), vget_high_s16(i2)); - return accum; -} - -// Performs multiply accumulate on 3 inputs of depth 8. 
-inline Int32x8 MultiplyAccumulate3x3Filter(const Filter3x3x8& f, int16x8_t i0, - int16x8_t i1, int16x8_t i2, - int16x8_t i3, int16x8_t i4, - int16x8_t i5, int16x8_t i6, - int16x8_t i7, int16x8_t i8, - Int32x8 accum) { - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f0), vget_low_s16(i0)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f0), vget_high_s16(i0)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f1), vget_low_s16(i1)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f1), vget_high_s16(i1)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f2), vget_low_s16(i2)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f2), vget_high_s16(i2)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f3), vget_low_s16(i3)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f3), vget_high_s16(i3)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f4), vget_low_s16(i4)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f4), vget_high_s16(i4)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f5), vget_low_s16(i5)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f5), vget_high_s16(i5)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f6), vget_low_s16(i6)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f6), vget_high_s16(i6)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f7), vget_low_s16(i7)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f7), vget_high_s16(i7)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f8), vget_low_s16(i8)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f8), vget_high_s16(i8)); - return accum; -} - -inline void DotProductAndStore(const Filter3x3x8& filter, int16x8_t i0, - int16x8_t i1, int16x8_t i2, int16x8_t i3, - int16x8_t i4, int16x8_t i5, int16x8_t i6, - int16x8_t i7, int16x8_t i8, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr) { - Int32x8 acc; - acc.low = 
vld1q_s32(bias_ptr); - acc.high = vld1q_s32(bias_ptr + 4); - - acc = MultiplyAccumulate3x3Filter(filter, i0, i1, i2, i3, i4, i5, i6, i7, i8, - acc); - - DownquantizeAndStore(acc, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, - output_ptr); -} - -// Performs multiply-accumulate on a 3x4 input for 2 horizontal outputs. -inline void DotProductAndStore2xStride1( - const Filter3x3x8& filter, int16x8_t i0, int16x8_t i1, int16x8_t i2, - int16x8_t i3, int16x8_t i4, int16x8_t i5, int16x8_t i6, int16x8_t i7, - int16x8_t i8, int16x8_t i9, int16x8_t i10, int16x8_t i11, - const int32* bias_ptr, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - uint8* output_ptr, int output_ptr_offset) { - Int32x8 acc_0, acc_1; - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_0.high = vld1q_s32(bias_ptr + 4); - acc_1.high = vld1q_s32(bias_ptr + 4); - - acc_0 = MultiplyAccumulate3x3Filter(filter, i0, i1, i2, i4, i5, i6, i8, i9, - i10, acc_0); - acc_1 = MultiplyAccumulate3x3Filter(filter, i1, i2, i3, i5, i6, i7, i9, i10, - i11, acc_1); - DownquantizeAndStore2Output(acc_0, acc_1, output_offset, output_multiplier, - output_shift, output_activation_min, - output_activation_max, output_ptr, - output_ptr_offset); -} - -// Performs multiply-accumulate on a 4x3 input for 2 vertical outputs. 
-inline void DotProductAndStore2yStride1( - const Filter3x3x8& filter, int16x8_t i0, int16x8_t i1, int16x8_t i2, - int16x8_t i3, int16x8_t i4, int16x8_t i5, int16x8_t i6, int16x8_t i7, - int16x8_t i8, int16x8_t i9, int16x8_t i10, int16x8_t i11, - const int32* bias_ptr, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - uint8* output_ptr, int output_ptr_offset) { - Int32x8 acc_0, acc_1; - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_0.high = vld1q_s32(bias_ptr + 4); - acc_1.high = vld1q_s32(bias_ptr + 4); - - acc_0 = MultiplyAccumulate3x3Filter(filter, i0, i1, i2, i3, i4, i5, i6, i7, - i8, acc_0); - acc_1 = MultiplyAccumulate3x3Filter(filter, i3, i4, i5, i6, i7, i8, i9, i10, - i11, acc_1); - DownquantizeAndStore2Output(acc_0, acc_1, output_offset, output_multiplier, - output_shift, output_activation_min, - output_activation_max, output_ptr, - output_ptr_offset); -} - -// A kernel that is optimized on the number of output cells in the x and y -// direction, and the stride. Assumes 3x3 filters of 8 depth. -template -struct ConvKernel3x3FilterDepth8 {}; - -template <> -struct ConvKernel3x3FilterDepth8<8, 8, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - const int output_row_size = output_depth * output_width; - - // To process 8x8 outputs using a 3x3 filter, we require 10x10 inputs. - // Load inputs for the first 2 filters on the top left, then slide to - // the right, down, left, down, right, etc. 
in a snake-like path. This - // minimizes the total number of loads. - // - // INPUT OUTPUT - // |\----------------\ |\------------\ - // | \ \ | \ \ - // | \----------------\ | \------------\ - // | | 0 ... 9 | | | 0 ... 7 | - // | | 10 ... 19 | ---> | | 8 ... 15 | - // | | 20 ... 29 | \ | .. ... .. | - // \ | .. ... .. | \| 56 ... 63 | - // \| 90 ... 109 | |------------| - // |----------------| - // - // The first set of loads corresponds to: - // - // INPUT OUTPUT - // |\----------------- |\----------- - // | \ | \ - // | \----------------- | \---------- - // | | 0 1 2 3 ... | | 0 1 ... - // | | 10 11 12 13 ... ---> | | .. ... - // | | 20 21 22 23 ... | .. ... - // | | .. ... ... - // - // The next set of loads correspond to a sliding window to the right. - // It loads inputs 4, 5, 14, 15, 23, 24 and keeps 2, 3, 12, 13, and 22: - // - // INPUT OUTPUT - // |\------------------- |\------------- - // | \ | \ - // | \------------------- | \------------ - // | | .. 2 3 4 5 ... | | .. 2 3 ... - // | | .. 12 13 14 15 ... ---> | | .. ... - // | | .. 21 22 23 24 ... | .. ... - // | | .. ... ... - // - // And so on... - - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - // Load inputs for 1x2 outputs starting from the top left. Referring to the - // indexes in the diagram above, this corresponds to outputs (0) and (1). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, 
bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - // Slide to the right for outputs x = [2, 3], y = 0. Referring to the - // indexes in the diagram above, this corresponds to outputs (2) and (3). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth, output_depth); - - // Slide to the right again for outputs x = [4, 5], y = 0. Referring to the - // indexes in the diagram above, this corresponds to outputs (4) and (5). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 6 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 4 * output_depth, output_depth); - - // Slide to the right one last time for outputs x = [6, 7], y = 0. - // Referring to the indexes in the diagram above, this corresponds to - // outputs (6) and (7). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 8 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 6 * output_depth, output_depth); - - // Slide to down for outputs x = [6, 7], y = 1. Referring to the indexes in - // the diagram above, this corresponds to outputs (14) and (15). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 6 * input_depth + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 6 * output_depth + output_row_size, - output_depth); - - // Slide left for outputs x = [4, 5], y = 1. Referring to the indexes in - // the diagram above, this corresponds to outputs (12) and (13). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 4 * output_depth + output_row_size, - output_depth); - - // Slide left again for outputs x = [2, 3], y = 1. Referring to the indexes - // in the diagram above, this corresponds to outputs (10) and (11). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 2 * input_depth + input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth + output_row_size, - output_depth); - - // Slide left one more time for outputs x = [0, 1], y = 1. Referring to the - // indexes in the diagram above, this corresponds to outputs (8) and (9). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + output_row_size, output_depth); - - // Slide down for outputs x = [0, 1], y = 2. Referring to the - // indexes in the diagram above, this corresponds to outputs (16) and (17). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2, - input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_row_size, output_depth); - - // Slide right for outputs x = [2, 3], y = 2. Referring to the - // indexes in the diagram above, this corresponds to outputs (18) and (19). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 2 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0, - input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 2 * output_row_size, output_depth); - - // Slide right for outputs x = [4, 5], y = 2. Referring to the - // indexes in the diagram above, this corresponds to outputs (20) and (21). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 6 * input_depth + 2 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2, - input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 4 * output_depth + 2 * output_row_size, output_depth); - - // Slide right one more time for outputs x = [6, 7], y = 2. Referring to the - // indexes in the diagram above, this corresponds to outputs (22) and (23). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 8 * input_depth + 2 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0, - input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 6 * output_depth + 2 * output_row_size, output_depth); - - // Slide down for outputs x = [6, 7], y = 3. Referring to the indexes in - // the diagram above, this corresponds to outputs (30) and (31). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 6 * input_depth + 5 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 6 * output_depth + 3 * output_row_size, output_depth); - - // Slide left for outputs x = [4, 5], y = 3. Referring to the indexes in - // the diagram above, this corresponds to outputs (28) and (29). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 4 * output_depth + 3 * output_row_size, output_depth); - - // Slide left for outputs x = [2, 3], y = 3. Referring to the indexes in - // the diagram above, this corresponds to outputs (26) and (27). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 2 * input_depth + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 3 * output_row_size, output_depth); - - // Slide left one more time for outputs x = [0, 1], y = 3. Referring to the - // indexes in the diagram above, this corresponds to outputs (24) and (25). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 3 * output_row_size, output_depth); - - // Slide down for outputs x = [0, 1], y = 4. Referring to the indexes in - // the diagram above, this corresponds to outputs (32) and (33). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 6 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 4 * output_row_size, output_depth); - - // Slide right for outputs x = [2, 3], y = 4. Referring to the indexes in - // the diagram above, this corresponds to outputs (34) and (35). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 4 * output_row_size, output_depth); - - // Slide right for outputs x = [4, 5], y = 4. Referring to the indexes in - // the diagram above, this corresponds to outputs (36) and (37). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 6 * input_depth + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 4 * output_depth + 4 * output_row_size, output_depth); - - // Slide right one more time for outputs x = [6, 7], y = 4. Referring to the - // indexes in the diagram above, this corresponds to outputs (38) and (39). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 8 * input_depth + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 6 * output_depth + 4 * output_row_size, output_depth); - - // Slide down for outputs x = [6, 7], y = 5. Referring to the indexes in - // the diagram above, this corresponds to outputs (46) and (47). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 6 * input_depth + 7 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0, - input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 6 * output_depth + 5 * output_row_size, output_depth); - - // Slide left for outputs x = [4, 5], y = 5. Referring to the indexes in - // the diagram above, this corresponds to outputs (44) and (45). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 5 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2, - input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 4 * output_depth + 5 * output_row_size, output_depth); - - // Slide left for outputs x = [2, 3], y = 5. Referring to the indexes in - // the diagram above, this corresponds to outputs (42) and (43). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 2 * input_depth + 5 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0, - input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 5 * output_row_size, output_depth); - - // Slide left one more time for outputs x = [0, 1], y = 5. Referring to the - // indexes in the diagram above, this corresponds to outputs (40) and (41). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 5 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2, - input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 5 * output_row_size, output_depth); - - // Slide down for outputs x = [0, 1], y = 6. Referring to the indexes in - // the diagram above, this corresponds to outputs (48) and (49). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 8 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 6 * output_row_size, output_depth); - - // Slide right for outputs x = [2, 3], y = 6. Referring to the indexes in - // the diagram above, this corresponds to outputs (50) and (51). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 6 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 6 * output_row_size, output_depth); - - // Slide right for outputs x = [4, 5], y = 6. Referring to the indexes in - // the diagram above, this corresponds to outputs (52) and (53). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 6 * input_depth + 6 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 4 * output_depth + 6 * output_row_size, output_depth); - - // Slide right one more time for outputs x = [6, 7], y = 6. Referring to the - // indexes in the diagram above, this corresponds to outputs (54) and (55). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 8 * input_depth + 6 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 6 * output_depth + 6 * output_row_size, output_depth); - - // Slide down for outputs x = [6, 7], y = 7. Referring to the indexes in the - // diagram above, this corresponds to outputs (62) and (63). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 6 * input_depth + 9 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 6 * output_depth + 7 * output_row_size, output_depth); - - // Slide left for outputs x = [4, 5], y = 7. Referring to the indexes in the - // diagram above, this corresponds to outputs (60) and (61). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 7 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 4 * output_depth + 7 * output_row_size, output_depth); - - // Slide left for outputs x = [2, 3], y = 7. Referring to the indexes in the - // diagram above, this corresponds to outputs (58) and (59). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 2 * input_depth + 7 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 7 * output_row_size, output_depth); - - // Slide left one more time for outputs x = [0, 1], y = 7. Referring to the - // indexes in the diagram above, this corresponds to outputs (56) and (57). 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 7 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 7 * output_row_size, output_depth); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<4, 4, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - const int output_row_size = output_depth * output_width; - - // To process 4x4 outputs using a 3x3 filter, we 
require 6x6 inputs. - // Load inputs for the first 2 filters on the top left, then slide to - // the right, down, left, down, right, etc. in a snake-like path. This - // minimizes the total number of loads. - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - // Load inputs for 1x2 outputs starting from the top left. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = 
vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - // Now load 1x2 inputs on the top right. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth, output_depth); - - // Now load next inputs when sliding window down. 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 2 * input_depth + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth + output_row_size, - output_depth); - - // Now load next inputs when sliding window left. 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + output_row_size, output_depth); - - // Now load next inputs when sliding window down. 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2, - input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_row_size, output_depth); - - // Now load next inputs when sliding window right. 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 2 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0, - input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 2 * output_row_size, output_depth); - - // Now load next inputs when sliding window down. 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 2 * input_depth + 5 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 3 * output_row_size, output_depth); - - // Now load next inputs when sliding window left. 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 3 * output_row_size, output_depth); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<4, 2, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - const int output_row_size = output_depth * output_width; - - int16x8_t input_0, input_1, input_2, input_3, 
input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - // Load inputs for 1x2 outputs starting from the top. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - 
DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - output_ptr += output_row_size; - - // Now load next inputs one row down. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - output_ptr += output_row_size; - - // Now load next row. 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2, - input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - output_ptr += output_row_size; - - // Now load last row. 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 5 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<4, 1, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - const int output_row_size = output_depth * output_width; - - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - // Load inputs for 2x1 outputs starting from the top. 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2yStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, 
input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_row_size); - - // Load inputs for bottom 2 rows. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2yStride1( - filter, input_6, input_7, input_8, input_9, input_10, input_11, input_0, - input_1, input_2, input_3, input_4, input_5, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_row_size, - output_row_size); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<2, 2, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, 
output_depth); - - Int32x8 acc_0, acc_1, acc_2, acc_3; - - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_2.low = vld1q_s32(bias_ptr); - acc_3.low = vld1q_s32(bias_ptr); - - bias_ptr += 4; - acc_0.high = vld1q_s32(bias_ptr); - acc_1.high = vld1q_s32(bias_ptr); - acc_2.high = vld1q_s32(bias_ptr); - acc_3.high = vld1q_s32(bias_ptr); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - - // Add scope for input registers to help the compiler know that it is - // not needed. - { - // To process 2x2 outputs using a 3x3 filter, we require 4x4 inputs. - // Load inputs for the top two filters first. - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - const uint8* ptr = input_ptr; - - // Load top 3 rows. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = 
vaddq_s16(input_7, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } +// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on +// Jetson TX-2. This compiler does not support the offsetof() macro. +#if defined(__aarch64__) && !defined(GOOGLE_L4T) - // Multiply-accum for top-left output. - acc_0 = MultiplyAccumulate3x3Filter(filter, input_0, input_1, input_2, - input_4, input_5, input_6, input_8, - input_9, input_10, acc_0); - - // Multiply-accum for top-right output. - acc_1 = MultiplyAccumulate3x3Filter(filter, input_1, input_2, input_3, - input_5, input_6, input_7, input_9, - input_10, input_11, acc_1); - - // Now load the bottom row. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - } - - // Multiply-accum for bottom-left output. 
- acc_2 = MultiplyAccumulate3x3Filter(filter, input_4, input_5, input_6, - input_8, input_9, input_10, input_0, - input_1, input_2, acc_2); - - // Multiply-accum for bottom-right output. - acc_3 = MultiplyAccumulate3x3Filter(filter, input_5, input_6, input_7, - input_9, input_10, input_11, input_1, - input_2, input_3, acc_3); - } - - DownquantizeAndStore2x2Output(acc_0, acc_1, acc_2, acc_3, output_offset, - output_multiplier, output_shift, - output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<2, 4, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - const int output_row_size = output_depth * output_width; - - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - // Load inputs for 1x2 outputs starting from the top left. 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } +// clang-format gets confused with this file and ends up formatting lines to +// be larger than 80 characters. 
Turn off here and back on at the end of the +// file. - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - // Now load 1x2 inputs on the top right. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth, output_depth); - - // Now load next inputs when sliding window down. 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 2 * input_depth + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth + output_row_size, - output_depth); - - // Now load next inputs when sliding window left. 
- { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } +// clang-format off - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + output_row_size, output_depth); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<1, 4, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, 
input_9, input_10, input_11; - - // Load inputs for 1x2 outputs starting from the left. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, 
input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - // Now load 1x2 inputs on the right. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + input_depth * 4; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } +#define DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE 10 * 10 * 64 - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth, output_depth); - } +// Encapsulates constant parameters used in DepthwiseConv. +// 64-bit is used for types that will be added to 64-bit addresses in asm. 
+struct DepthwiseConvParams { + int64_t input_depth; + int64_t input_row_size; + int64_t output_depth; + int64_t output_row_size; + int64_t filter_row_size; + int32 input_offset; + int32 output_offset; + int32 filter_offset; + int32 output_multiplier; + int32 output_activation_min; + int32 output_activation_max; + int32 output_shift; + int32 input_width; + int32 input_height; + int32 stride_width; + int32 stride_height; + int32 output_width; + int32 output_height; }; -template <> -struct ConvKernel3x3FilterDepth8<2, 1, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - // To process 2x1 outputs using a 3x3 filter, we require 4x3 inputs. - // Load all inputs at the beginning. - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - // Load inputs for 1x2 outputs starting from the top left. 
- { - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2yStride1( - filter, input_0, input_1, input_2, input_3, 
input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth * output_width); - } -}; +#define STR(s) STR_UNEXPANDED(s) +#define STR_UNEXPANDED(s) #s + +// Represents the number of bytes offset from the start of the +// DepthwiseConvParams struct. This is used in the asm to load parameters. +// Keep these values in sync with the static_asserts below. +#define OFFSET_INPUT_DEPTH 0 +#define OFFSET_INPUT_ROW_SIZE 8 +#define OFFSET_OUTPUT_DEPTH 16 +#define OFFSET_OUTPUT_ROW_SIZE 24 +#define OFFSET_FILTER_ROW_SIZE 32 +#define OFFSET_INPUT_OFFSET 40 +#define OFFSET_OUTPUT_OFFSET 44 +#define OFFSET_FILTER_OFFSET 48 +#define OFFSET_OUTPUT_MULTIPLIER 52 +#define OFFSET_OUTPUT_ACTIVATION_MIN 56 +#define OFFSET_OUTPUT_ACTIVATION_MAX 60 +#define OFFSET_OUTPUT_SHIFT 64 +#define OFFSET_INPUT_WIDTH 68 +#define OFFSET_INPUT_HEIGHT 72 +#define OFFSET_STRIDE_WIDTH 76 +#define OFFSET_STRIDE_HEIGHT 80 +#define OFFSET_OUTPUT_WIDTH 84 +#define OFFSET_OUTPUT_HEIGHT 88 + +static_assert(offsetof(DepthwiseConvParams, input_depth) == + OFFSET_INPUT_DEPTH, ""); +static_assert(offsetof(DepthwiseConvParams, input_row_size) == + OFFSET_INPUT_ROW_SIZE, ""); +static_assert(offsetof(DepthwiseConvParams, output_depth) == + OFFSET_OUTPUT_DEPTH, ""); +static_assert(offsetof(DepthwiseConvParams, output_row_size) == + OFFSET_OUTPUT_ROW_SIZE, ""); +static_assert(offsetof(DepthwiseConvParams, filter_row_size) == + OFFSET_FILTER_ROW_SIZE, ""); +static_assert(offsetof(DepthwiseConvParams, input_offset) == + OFFSET_INPUT_OFFSET, ""); +static_assert(offsetof(DepthwiseConvParams, output_offset) == + OFFSET_OUTPUT_OFFSET, ""); +static_assert(offsetof(DepthwiseConvParams, filter_offset) == + OFFSET_FILTER_OFFSET, ""); +static_assert(offsetof(DepthwiseConvParams, output_multiplier) == + OFFSET_OUTPUT_MULTIPLIER, ""); +static_assert(offsetof(DepthwiseConvParams, 
output_activation_min) == + OFFSET_OUTPUT_ACTIVATION_MIN, ""); +static_assert(offsetof(DepthwiseConvParams, output_activation_max) == + OFFSET_OUTPUT_ACTIVATION_MAX, ""); +static_assert(offsetof(DepthwiseConvParams, output_shift) == + OFFSET_OUTPUT_SHIFT, ""); +static_assert(offsetof(DepthwiseConvParams, input_width) == + OFFSET_INPUT_WIDTH, ""); +static_assert(offsetof(DepthwiseConvParams, input_height) == + OFFSET_INPUT_HEIGHT, ""); +static_assert(offsetof(DepthwiseConvParams, stride_width) == + OFFSET_STRIDE_WIDTH, ""); +static_assert(offsetof(DepthwiseConvParams, stride_height) == + OFFSET_STRIDE_HEIGHT, ""); +static_assert(offsetof(DepthwiseConvParams, output_width) == + OFFSET_OUTPUT_WIDTH, ""); +static_assert(offsetof(DepthwiseConvParams, output_height) == + OFFSET_OUTPUT_HEIGHT, ""); + +template +struct DepthwiseConvWindow {}; template <> -struct ConvKernel3x3FilterDepth8<4, 2, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - const int output_row_size = output_depth * output_width; - - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - Int32x8 acc_0, acc_1; - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_0.high = vld1q_s32(bias_ptr + 4); - acc_1.high = vld1q_s32(bias_ptr + 4); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9; - - const uint8* ptr = input_ptr; - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4; - - // Load first 2 rows. 
- temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2, - input_2, input_3, input_4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5, - input_5, input_6, input_7); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5, - input_7, input_8, input_9); - - // Load next 2 rows. 
- ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8, - input_2, input_3, input_4); - - DownquantizeAndStore2Output( - acc_0, acc_1, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, output_depth); - - output_ptr += output_row_size; - - // Moving onto the next row of outputs. 
- acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_0.high = vld1q_s32(bias_ptr + 4); - acc_1.high = vld1q_s32(bias_ptr + 4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2, - input_2, input_3, input_4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5, - input_5, input_6, input_7); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5, - input_7, input_8, input_9); - - // Load next 2 rows. - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, 
input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8, - input_2, input_3, input_4); - - DownquantizeAndStore2Output( - acc_0, acc_1, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, output_depth); - - output_ptr += output_row_size; - - // Moving onto the next row of outputs. - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_0.high = vld1q_s32(bias_ptr + 4); - acc_1.high = vld1q_s32(bias_ptr + 4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2, - input_2, input_3, input_4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5, - input_5, input_6, input_7); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5, - input_7, input_8, input_9); - - // Load next 2 rows. 
- ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8, - input_2, input_3, input_4); - - DownquantizeAndStore2Output( - acc_0, acc_1, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, output_depth); - - output_ptr += output_row_size; - - // Moving onto the next row of outputs. 
- acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_0.high = vld1q_s32(bias_ptr + 4); - acc_1.high = vld1q_s32(bias_ptr + 4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2, - input_2, input_3, input_4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5, - input_5, input_6, input_7); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5, - input_7, input_8, input_9); - - // Load last row. - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8, - input_2, input_3, input_4); - - DownquantizeAndStore2Output( - acc_0, acc_1, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, output_depth); +struct DepthwiseConvWindow<8, 1, 1> { + public: + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, int64_t input_depth, + int64_t input_row_size, int32 output_window_height, + int32 output_window_width, + const 
DepthwiseConvParams* params_ptr) { + const int64_t input_width_increment = 2 * input_depth; + const int64_t input_height_increment = 2 * input_row_size; + const int64_t output_height_increment = 2 * params_ptr->output_row_size; + +#define DEPTHWISECONV_LABEL_HEIGHT_2_LOOP "1" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP "2" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "3" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER "4" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP "5" +#define DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP "6" +#define DEPTHWISECONV_LABEL_HEIGHT_1 "7" +#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP "8" +#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "9" +#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER "10" +#define DEPTHWISECONV_LABEL_HEIGHT_1_END "11" + + asm volatile( + // Performs depthwise convolutions for a window specified by + // |output_window_height| and |output_window_width|. The inner-most loop + // processes 2x2 outputs, and any leftovers at the end. + // + // Algorithm works as follows: + // + // 1. Load filters of 8 depth (8x3x3). Registers v0--v8 hold filter + // values. + // 2. For 2 output heights at a time: + // i. For 2 output widths at a time, load inputs for a 2x1 (2 + // height, 1 width) output window (4x3 input window). + // Registers v9--v20 hold input values. Mul-add with + // accumulators v21--v24. Then run activation, downquantize + // and store. Repeat for the next 2x1 output window, + // leveraging overlapping inputs. + // ii. Handle single leftover width if exists. + // 3. Handle single leftover height if exists. + // i. For 2 output widths at a time, load inputs for a 1x2 (1 + // height, 2 width) output window (3x4 input window). + // Registers v9--v20 hold input values. Mul-add with + // accumulators v21--v24. Then run activation, downquantize + // and store. Repeat for the next 1x2 output window, + // leveraging overlapping inputs. + // ii. 
Handle single leftover width if exists. + // + // Loads are placed as soon as the register is no longer needed and + // interleaved with arithmetic operations to take advantage of + // dual-issue pipelines. We also add input offsets as far from the loads + // as possible to give loads enough cycles to fetch data from memory. + + // Set "constant" registers. These registers may be replaced with temp + // values from time to time when there are not enough NEON registers. + // We use x9--x15 general purpose registers as they are caller-saved + // temporary registers (see http://infocenter.arm.com/help/topic/com.arm.doc.ihi0055b/IHI0055B_aapcs64.pdf). // NOLINT + "ldr w9, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "ldr x3, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + "cmp %w[output_window_height], #2\n" + "dup v26.8h, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "ldr w2, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "dup v29.4s, w2\n" + "ldr w4, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "dup v30.4s, w4\n" + "ldr w0, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v31.4s, w0\n" + "neg w9, w9\n" + "dup v28.4s, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + "add x10, %[bias_ptr], #16\n" + "ldr x1, [%[params_ptr], #" STR(OFFSET_OUTPUT_ROW_SIZE) "]\n" + "dup v9.8h, w9\n" + + // Load filters and add offsets. 
+ "ld1 {v0.8b}, [%[filter_ptr]], x3\n" + "ld1 {v1.8b}, [%[filter_ptr]], x3\n" + "uaddw v0.8h, v9.8h, v0.8b\n" + "ld1 {v2.8b}, [%[filter_ptr]], x3\n" + "uaddw v1.8h, v9.8h, v1.8b\n" + "ld1 {v3.8b}, [%[filter_ptr]], x3\n" + "uaddw v2.8h, v9.8h, v2.8b\n" + "ld1 {v4.8b}, [%[filter_ptr]], x3\n" + "uaddw v3.8h, v9.8h, v3.8b\n" + "ld1 {v5.8b}, [%[filter_ptr]], x3\n" + "uaddw v4.8h, v9.8h, v4.8b\n" + "ld1 {v6.8b}, [%[filter_ptr]], x3\n" + "uaddw v5.8h, v9.8h, v5.8b\n" + "ld1 {v7.8b}, [%[filter_ptr]], x3\n" + "uaddw v6.8h, v9.8h, v6.8b\n" + "ld1 {v8.8b}, [%[filter_ptr]], x3\n" + "uaddw v7.8h, v9.8h, v7.8b\n" + "uaddw v8.8h, v9.8h, v8.8b\n" + + "blt " DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_HEIGHT_2_LOOP ":\n" + // This loop processes 2x2 outputs. To avoid register exhaustion, + // inputs for the left 2 outputs are loaded first, then the right + // two outputs. + "mov x11, %[input_ptr]\n" + "mov x12, x11\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "add x13, x11, %[input_row_size]\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "add x14, x13, %[input_row_size]\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "add x15, x14, %[input_row_size]\n" + "ld1 {v12.8b}, [x13], %[input_depth]\n" + "mov w5, %w[output_window_width]\n" + "ld1 {v13.8b}, [x13], %[input_depth]\n" + "mov x6, %[output_ptr]\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "add x7, %[output_ptr], x1\n" + "ld1 {v15.8b}, [x14], %[input_depth]\n" + // The height 2 / width 2 loop loads an extra 2x1 outputs (2 height, + // 1 width) in anticipation for the next iteration. Make sure + // |output_window_width| is large enough to handle the additional + // loads, otherwise jump to the appropriate label to handle + // smaller widths. 
+ "cmp w5, #2\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "ld1 {v16.8b}, [x14], %[input_depth]\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "ld1 {v17.8b}, [x14], %[input_depth]\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "ld1 {v18.8b}, [x15], %[input_depth]\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "ld1 {v19.8b}, [x15], %[input_depth]\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + "ld1 {v20.8b}, [x15], %[input_depth]\n" + "uaddw v14.8h, v26.8h, v14.8b\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "uaddw v15.8h, v26.8h, v15.8b\n" + "ld1 {v22.4s}, [x10]\n" + "uaddw v16.8h, v26.8h, v16.8b\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + "uaddw v17.8h, v26.8h, v17.8b\n" + "ld1 {v24.4s}, [x10]\n" + "uaddw v18.8h, v26.8h, v18.8b\n" + "uaddw v19.8h, v26.8h, v19.8b\n" + "uaddw v20.8h, v26.8h, v20.8b\n" + + "beq " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER "f\n" + "cmp w5, #1\n" + "beq " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP ":\n" + // Mul-add left outputs. 
+ "smlal v21.4s, v0.4h, v9.4h\n" + "subs w5, w5, #2\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "cmp w5, #3\n" + "smlal v23.4s, v0.4h, v12.4h\n" + "ld1 {v9.8b}, [x12]\n" + "smlal2 v24.4s, v0.8h, v12.8h\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "smlal v23.4s, v1.4h, v13.4h\n" + "smlal2 v24.4s, v1.8h, v13.8h\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "smlal v23.4s, v2.4h, v14.4h\n" + "smlal2 v24.4s, v2.8h, v14.8h\n" + "smlal v21.4s, v3.4h, v12.4h\n" + "smlal2 v22.4s, v3.8h, v12.8h\n" + "ld1 {v12.8b}, [x13]\n" + "smlal v23.4s, v3.4h, v15.4h\n" + "smlal2 v24.4s, v3.8h, v15.8h\n" + "smlal v21.4s, v4.4h, v13.4h\n" + "smlal2 v22.4s, v4.8h, v13.8h\n" + "smlal v23.4s, v4.4h, v16.4h\n" + "smlal2 v24.4s, v4.8h, v16.8h\n" + "smlal v21.4s, v5.4h, v14.4h\n" + "smlal2 v22.4s, v5.8h, v14.8h\n" + "smlal v23.4s, v5.4h, v17.4h\n" + "smlal2 v24.4s, v5.8h, v17.8h\n" + "smlal v21.4s, v6.4h, v15.4h\n" + "smlal2 v22.4s, v6.8h, v15.8h\n" + "ld1 {v15.8b}, [x14]\n" + "smlal v23.4s, v6.4h, v18.4h\n" + "smlal2 v24.4s, v6.8h, v18.8h\n" + "ld1 {v18.8b}, [x15]\n" + "smlal v21.4s, v7.4h, v16.4h\n" + "smlal2 v22.4s, v7.8h, v16.8h\n" + "smlal v23.4s, v7.4h, v19.4h\n" + "smlal2 v24.4s, v7.8h, v19.8h\n" + "smlal v21.4s, v8.4h, v17.4h\n" + "smlal2 v22.4s, v8.8h, v17.8h\n" + "smlal v23.4s, v8.4h, v20.4h\n" + "smlal2 v24.4s, v8.8h, v20.8h\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v25.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v25.4s, v25.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v25.4s\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w4\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup 
v31.4s, w0\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "ld1 {v22.4s}, [x10]\n" + "sqxtn2 v23.8h, v24.4s\n" + "ld1 {v24.4s}, [x10]\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "st1 {v21.8b}, [x6], x3\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "st1 {v23.8b}, [x7], x3\n" + "uaddw v15.8h, v26.8h, v15.8b\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "uaddw v18.8h, v26.8h, v18.8b\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + + // Mul-add right outputs. 
+ "smlal v21.4s, v0.4h, v10.4h\n" + "add x11, x11, %[input_width_increment]\n" + "smlal2 v22.4s, v0.8h, v10.8h\n" + "mov x12, x11\n" + "smlal v23.4s, v0.4h, v13.4h\n" + "add x13, x11, %[input_row_size]\n" + "smlal2 v24.4s, v0.8h, v13.8h\n" + "add x14, x13, %[input_row_size]\n" + "smlal v21.4s, v1.4h, v11.4h\n" + "add x15, x14, %[input_row_size]\n" + "smlal2 v22.4s, v1.8h, v11.8h\n" + "smlal v23.4s, v1.4h, v14.4h\n" + "smlal2 v24.4s, v1.8h, v14.8h\n" + "smlal v21.4s, v2.4h, v9.4h\n" + "smlal2 v22.4s, v2.8h, v9.8h\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v2.4h, v12.4h\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "smlal2 v24.4s, v2.8h, v12.8h\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "smlal v21.4s, v3.4h, v13.4h\n" + "smlal2 v22.4s, v3.8h, v13.8h\n" + "smlal v23.4s, v3.4h, v16.4h\n" + "smlal2 v24.4s, v3.8h, v16.8h\n" + "smlal v21.4s, v4.4h, v14.4h\n" + "smlal2 v22.4s, v4.8h, v14.8h\n" + "smlal v23.4s, v4.4h, v17.4h\n" + "smlal2 v24.4s, v4.8h, v17.8h\n" + "smlal v21.4s, v5.4h, v12.4h\n" + "smlal2 v22.4s, v5.8h, v12.8h\n" + "ld1 {v12.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v5.4h, v15.4h\n" + "ld1 {v13.8b}, [x13], %[input_depth]\n" + "smlal2 v24.4s, v5.8h, v15.8h\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "smlal v21.4s, v6.4h, v16.4h\n" + "smlal2 v22.4s, v6.8h, v16.8h\n" + "smlal v23.4s, v6.4h, v19.4h\n" + "smlal2 v24.4s, v6.8h, v19.8h\n" + "smlal v21.4s, v7.4h, v17.4h\n" + "smlal2 v22.4s, v7.8h, v17.8h\n" + "smlal v23.4s, v7.4h, v20.4h\n" + "smlal2 v24.4s, v7.8h, v20.8h\n" + "smlal v21.4s, v8.4h, v15.4h\n" + "smlal2 v22.4s, v8.8h, v15.8h\n" + "ld1 {v15.8b}, [x14], %[input_depth]\n" + "smlal v23.4s, v8.4h, v18.4h\n" + "ld1 {v16.8b}, [x14], %[input_depth]\n" + "smlal2 v24.4s, v8.8h, v18.8h\n" + "ld1 {v17.8b}, [x14], %[input_depth]\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "ld1 {v18.8b}, [x15], %[input_depth]\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "ld1 {v19.8b}, [x15], %[input_depth]\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + 
"ld1 {v20.8b}, [x15], %[input_depth]\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v25.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v25.4s, v25.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v25.4s\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w4\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w0\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "ld1 {v22.4s}, [x10]\n" + "sqxtn2 v23.8h, v24.4s\n" + "ld1 {v24.4s}, [x10]\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "st1 {v21.8b}, [x6], x3\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "st1 {v23.8b}, [x7], x3\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + "uaddw v14.8h, v26.8h, v14.8b\n" + "uaddw v15.8h, v26.8h, v15.8b\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "uaddw v16.8h, v26.8h, v16.8b\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + "uaddw v17.8h, v26.8h, v17.8b\n" + "uaddw v18.8h, v26.8h, v18.8b\n" + "uaddw v19.8h, v26.8h, v19.8b\n" + "uaddw v20.8h, v26.8h, v20.8b\n" + + "bge " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP "b\n" + + // At this point, there will be one of 2 width or 1 width leftover, + // not 
both. + "cmp w5, #2\n" + "blt " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "f\n" + + // Handle last 2 columns if exists. + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER ":\n" + // Mul-add left outputs. + "smlal v21.4s, v0.4h, v9.4h\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "smlal v23.4s, v0.4h, v12.4h\n" + "ld1 {v9.8b}, [x12]\n" + "smlal2 v24.4s, v0.8h, v12.8h\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "smlal v23.4s, v1.4h, v13.4h\n" + "smlal2 v24.4s, v1.8h, v13.8h\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "smlal v23.4s, v2.4h, v14.4h\n" + "smlal2 v24.4s, v2.8h, v14.8h\n" + "smlal v21.4s, v3.4h, v12.4h\n" + "smlal2 v22.4s, v3.8h, v12.8h\n" + "ld1 {v12.8b}, [x13]\n" + "smlal v23.4s, v3.4h, v15.4h\n" + "smlal2 v24.4s, v3.8h, v15.8h\n" + "smlal v21.4s, v4.4h, v13.4h\n" + "smlal2 v22.4s, v4.8h, v13.8h\n" + "smlal v23.4s, v4.4h, v16.4h\n" + "smlal2 v24.4s, v4.8h, v16.8h\n" + "smlal v21.4s, v5.4h, v14.4h\n" + "smlal2 v22.4s, v5.8h, v14.8h\n" + "smlal v23.4s, v5.4h, v17.4h\n" + "smlal2 v24.4s, v5.8h, v17.8h\n" + "smlal v21.4s, v6.4h, v15.4h\n" + "smlal2 v22.4s, v6.8h, v15.8h\n" + "ld1 {v15.8b}, [x14]\n" + "smlal v23.4s, v6.4h, v18.4h\n" + "smlal2 v24.4s, v6.8h, v18.8h\n" + "ld1 {v18.8b}, [x15]\n" + "smlal v21.4s, v7.4h, v16.4h\n" + "smlal2 v22.4s, v7.8h, v16.8h\n" + "smlal v23.4s, v7.4h, v19.4h\n" + "smlal2 v24.4s, v7.8h, v19.8h\n" + "smlal v21.4s, v8.4h, v17.4h\n" + "smlal2 v22.4s, v8.8h, v17.8h\n" + "smlal v23.4s, v8.4h, v20.4h\n" + "smlal2 v24.4s, v8.8h, v20.8h\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v25.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v25.4s, v25.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, 
v25.4s\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w4\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w0\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "ld1 {v22.4s}, [x10]\n" + "sqxtn2 v23.8h, v24.4s\n" + "ld1 {v24.4s}, [x10]\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "st1 {v21.8b}, [x6], x3\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "st1 {v23.8b}, [x7], x3\n" + "uaddw v15.8h, v26.8h, v15.8b\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "uaddw v18.8h, v26.8h, v18.8b\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + + // Mul-add right outputs. 
+ "smlal v21.4s, v0.4h, v10.4h\n" + "smlal2 v22.4s, v0.8h, v10.8h\n" + "smlal v23.4s, v0.4h, v13.4h\n" + "smlal2 v24.4s, v0.8h, v13.8h\n" + "smlal v21.4s, v1.4h, v11.4h\n" + "smlal2 v22.4s, v1.8h, v11.8h\n" + "smlal v23.4s, v1.4h, v14.4h\n" + "smlal2 v24.4s, v1.8h, v14.8h\n" + "smlal v21.4s, v2.4h, v9.4h\n" + "smlal2 v22.4s, v2.8h, v9.8h\n" + "smlal v23.4s, v2.4h, v12.4h\n" + "smlal2 v24.4s, v2.8h, v12.8h\n" + "smlal v21.4s, v3.4h, v13.4h\n" + "smlal2 v22.4s, v3.8h, v13.8h\n" + "smlal v23.4s, v3.4h, v16.4h\n" + "smlal2 v24.4s, v3.8h, v16.8h\n" + "smlal v21.4s, v4.4h, v14.4h\n" + "smlal2 v22.4s, v4.8h, v14.8h\n" + "smlal v23.4s, v4.4h, v17.4h\n" + "smlal2 v24.4s, v4.8h, v17.8h\n" + "smlal v21.4s, v5.4h, v12.4h\n" + "smlal2 v22.4s, v5.8h, v12.8h\n" + "smlal v23.4s, v5.4h, v15.4h\n" + "smlal2 v24.4s, v5.8h, v15.8h\n" + "smlal v21.4s, v6.4h, v16.4h\n" + "smlal2 v22.4s, v6.8h, v16.8h\n" + "smlal v23.4s, v6.4h, v19.4h\n" + "smlal2 v24.4s, v6.8h, v19.8h\n" + "smlal v21.4s, v7.4h, v17.4h\n" + "smlal2 v22.4s, v7.8h, v17.8h\n" + "smlal v23.4s, v7.4h, v20.4h\n" + "smlal2 v24.4s, v7.8h, v20.8h\n" + "smlal v21.4s, v8.4h, v15.4h\n" + "smlal2 v22.4s, v8.8h, v15.8h\n" + "smlal v23.4s, v8.4h, v18.4h\n" + "smlal2 v24.4s, v8.8h, v18.8h\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v25.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v25.4s, v25.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v25.4s\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w4\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w0\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, 
v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "sqxtn2 v23.8h, v24.4s\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "st1 {v21.8b}, [x6], x3\n" + "st1 {v23.8b}, [x7], x3\n" + "b " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP "f\n" + + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER ":\n" + "smlal v21.4s, v0.4h, v9.4h\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "smlal v23.4s, v0.4h, v12.4h\n" + "smlal2 v24.4s, v0.8h, v12.8h\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "smlal v23.4s, v1.4h, v13.4h\n" + "smlal2 v24.4s, v1.8h, v13.8h\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "smlal v23.4s, v2.4h, v14.4h\n" + "smlal2 v24.4s, v2.8h, v14.8h\n" + "smlal v21.4s, v3.4h, v12.4h\n" + "smlal2 v22.4s, v3.8h, v12.8h\n" + "smlal v23.4s, v3.4h, v15.4h\n" + "smlal2 v24.4s, v3.8h, v15.8h\n" + "smlal v21.4s, v4.4h, v13.4h\n" + "smlal2 v22.4s, v4.8h, v13.8h\n" + "smlal v23.4s, v4.4h, v16.4h\n" + "smlal2 v24.4s, v4.8h, v16.8h\n" + "smlal v21.4s, v5.4h, v14.4h\n" + "smlal2 v22.4s, v5.8h, v14.8h\n" + "smlal v23.4s, v5.4h, v17.4h\n" + "smlal2 v24.4s, v5.8h, v17.8h\n" + "smlal v21.4s, v6.4h, v15.4h\n" + "smlal2 v22.4s, v6.8h, v15.8h\n" + "smlal v23.4s, v6.4h, v18.4h\n" + "smlal2 v24.4s, v6.8h, v18.8h\n" + "smlal v21.4s, v7.4h, v16.4h\n" + "smlal2 v22.4s, v7.8h, v16.8h\n" + "smlal v23.4s, v7.4h, v19.4h\n" + "smlal2 v24.4s, v7.8h, v19.8h\n" + "smlal v21.4s, v8.4h, v17.4h\n" + "smlal2 v22.4s, v8.8h, v17.8h\n" + "smlal v23.4s, v8.4h, v20.4h\n" + "smlal2 v24.4s, v8.8h, v20.8h\n" + + 
"sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v9.16b, v21.16b, v28.16b\n" + "and v12.16b, v22.16b, v28.16b\n" + "and v15.16b, v23.16b, v28.16b\n" + "and v18.16b, v24.16b, v28.16b\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v12.4s, v12.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sshr v18.4s, v18.4s, #31\n" + "sqadd v21.4s, v21.4s, v9.4s\n" + "sqadd v22.4s, v22.4s, v12.4s\n" + "sqadd v23.4s, v23.4s, v15.4s\n" + "sqadd v24.4s, v24.4s, v18.4s\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "sqxtn2 v23.8h, v24.4s\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "st1 {v21.8b}, [x6], x3\n" + "st1 {v23.8b}, [x7], x3\n" + + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP ":\n" + "subs %w[output_window_height], %w[output_window_height], #2\n" + "add %[input_ptr], %[input_ptr], %[input_height_increment]\n" + "cmp %w[output_window_height], #2\n" + "add %[output_ptr], %[output_ptr], %[output_height_increment]\n" + "bge " DEPTHWISECONV_LABEL_HEIGHT_2_LOOP "b\n" + + DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP ":\n" + "cmp %w[output_window_height], #1\n" + "blt " DEPTHWISECONV_LABEL_HEIGHT_1_END "f\n" + + DEPTHWISECONV_LABEL_HEIGHT_1 ":\n" + "mov x12, %[input_ptr]\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "add x13, %[input_ptr], %[input_row_size]\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" 
+ "add x14, x13, %[input_row_size]\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "add x15, x14, %[input_row_size]\n" + "mov w5, %w[output_window_width]\n" + "ld1 {v13.8b}, [x13], %[input_depth]\n" + "mov x6, %[output_ptr]\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "add x7, %[output_ptr], x1\n" + "ld1 {v15.8b}, [x13], %[input_depth]\n" + // The height 1 / width 2 loop loads an extra 1x1 output in anticipation + // for the next iteration. Make sure |output_window_width| is large + // enough to handle the additional load, otherwise jump to the + // appropriate label to handle smaller widths. + "cmp w5, #2\n" + "ld1 {v17.8b}, [x14], %[input_depth]\n" + "ld1 {v18.8b}, [x14], %[input_depth]\n" + "ld1 {v19.8b}, [x14], %[input_depth]\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "ld1 {v22.4s}, [x10]\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + "ld1 {v24.4s}, [x10]\n" + + "uaddw v9.8h, v26.8h, v9.8b\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + "uaddw v14.8h, v26.8h, v14.8b\n" + "uaddw v15.8h, v26.8h, v15.8b\n" + "uaddw v17.8h, v26.8h, v17.8b\n" + "uaddw v18.8h, v26.8h, v18.8b\n" + "uaddw v19.8h, v26.8h, v19.8b\n" + + "beq " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER "f\n" + "cmp w5, #1\n" + "beq " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP ":\n" + // Load inputs for 3x4 input window which corresponds to a 1x2 output + // window. 
+ "smlal v21.4s, v0.4h, v9.4h\n" + "ld1 {v12.8b}, [x12]\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "ld1 {v16.8b}, [x13]\n" + "smlal v23.4s, v0.4h, v10.4h\n" + "ld1 {v20.8b}, [x14]\n" + "smlal2 v24.4s, v0.8h, v10.8h\n" + "subs w5, w5, #2\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "cmp w5, #3\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "add %[input_ptr], %[input_ptr], %[input_width_increment]\n" + "smlal v23.4s, v1.4h, v11.4h\n" + "mov x12, %[input_ptr]\n" + "smlal2 v24.4s, v1.8h, v11.8h\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "add x13, %[input_ptr], %[input_row_size]\n" + "smlal v23.4s, v2.4h, v12.4h\n" + "add x14, x13, %[input_row_size]\n" + "smlal2 v24.4s, v2.8h, v12.8h\n" + "smlal v21.4s, v3.4h, v13.4h\n" + "add x15, x14, %[input_row_size]\n" + "smlal2 v22.4s, v3.8h, v13.8h\n" + "ld1 {v13.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v3.4h, v14.4h\n" + "smlal2 v24.4s, v3.8h, v14.8h\n" + "smlal v21.4s, v4.4h, v14.4h\n" + "smlal2 v22.4s, v4.8h, v14.8h\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v4.4h, v15.4h\n" + "smlal2 v24.4s, v4.8h, v15.8h\n" + "smlal v21.4s, v5.4h, v15.4h\n" + "uaddw v16.8h, v26.8h, v16.8b\n" + "smlal2 v22.4s, v5.8h, v15.8h\n" + "ld1 {v15.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v5.4h, v16.4h\n" + "smlal2 v24.4s, v5.8h, v16.8h\n" + "smlal v21.4s, v6.4h, v17.4h\n" + "smlal2 v22.4s, v6.8h, v17.8h\n" + "ld1 {v17.8b}, [x14], %[input_depth]\n" + "smlal v23.4s, v6.4h, v18.4h\n" + "smlal2 v24.4s, v6.8h, v18.8h\n" + "smlal v21.4s, v7.4h, v18.4h\n" + "smlal2 v22.4s, v7.8h, v18.8h\n" + "ld1 {v18.8b}, [x14], %[input_depth]\n" + "smlal v23.4s, v7.4h, v19.4h\n" + "smlal2 v24.4s, v7.8h, v19.8h\n" + "smlal v21.4s, v8.4h, v19.4h\n" + "uaddw v20.8h, v26.8h, v20.8b\n" + "smlal2 v22.4s, v8.8h, v19.8h\n" + "ld1 {v19.8b}, [x14], %[input_depth]\n" + "smlal 
v23.4s, v8.4h, v20.4h\n" + "smlal2 v24.4s, v8.8h, v20.8h\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v25.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v25.4s, v25.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v25.4s\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w4\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w0\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "ld1 {v22.4s}, [x10]\n" + "sqxtn2 v23.8h, v24.4s\n" + "ld1 {v24.4s}, [x10]\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "st1 {v21.8b}, [%[output_ptr]], x3\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "st1 {v23.8b}, [%[output_ptr]], x3\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + "uaddw v14.8h, v26.8h, v14.8b\n" + "uaddw v15.8h, v26.8h, v15.8b\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "uaddw v16.8h, v26.8h, v16.8b\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + "uaddw v17.8h, v26.8h, v17.8b\n" + "uaddw v18.8h, v26.8h, v18.8b\n" + "uaddw v19.8h, v26.8h, v19.8b\n" + "uaddw v20.8h, 
v26.8h, v20.8b\n" + + "bge " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP "b\n" + + // At this point, there will be one of 2 width or 1 width leftover, + // not both. + "cmp w5, #2\n" + "blt " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "f\n" + + // Handle last two horizontal outputs if exists. + DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER ":\n" + "smlal v21.4s, v0.4h, v9.4h\n" + "ld1 {v12.8b}, [x12], %[input_depth]\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "ld1 {v16.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v0.4h, v10.4h\n" + "ld1 {v20.8b}, [x14], %[input_depth]\n" + "smlal2 v24.4s, v0.8h, v10.8h\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "smlal v23.4s, v1.4h, v11.4h\n" + "smlal2 v24.4s, v1.8h, v11.8h\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "smlal v23.4s, v2.4h, v12.4h\n" + "smlal2 v24.4s, v2.8h, v12.8h\n" + "smlal v21.4s, v3.4h, v13.4h\n" + "smlal2 v22.4s, v3.8h, v13.8h\n" + "smlal v23.4s, v3.4h, v14.4h\n" + "smlal2 v24.4s, v3.8h, v14.8h\n" + "smlal v21.4s, v4.4h, v14.4h\n" + "smlal2 v22.4s, v4.8h, v14.8h\n" + "smlal v23.4s, v4.4h, v15.4h\n" + "smlal2 v24.4s, v4.8h, v15.8h\n" + "smlal v21.4s, v5.4h, v15.4h\n" + "uaddw v16.8h, v26.8h, v16.8b\n" + "smlal2 v22.4s, v5.8h, v15.8h\n" + "smlal v23.4s, v5.4h, v16.4h\n" + "smlal2 v24.4s, v5.8h, v16.8h\n" + "smlal v21.4s, v6.4h, v17.4h\n" + "smlal2 v22.4s, v6.8h, v17.8h\n" + "smlal v23.4s, v6.4h, v18.4h\n" + "smlal2 v24.4s, v6.8h, v18.8h\n" + "smlal v21.4s, v7.4h, v18.4h\n" + "smlal2 v22.4s, v7.8h, v18.8h\n" + "smlal v23.4s, v7.4h, v19.4h\n" + "smlal2 v24.4s, v7.8h, v19.8h\n" + "smlal v21.4s, v8.4h, v19.4h\n" + "uaddw v20.8h, v26.8h, v20.8b\n" + "smlal2 v22.4s, v8.8h, v19.8h\n" + "smlal v23.4s, v8.4h, v20.4h\n" + "smlal2 v24.4s, v8.8h, v20.8h\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v25.16b, v21.16b, 
v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v25.4s, v25.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v25.4s\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w4\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w0\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "sqxtn2 v23.8h, v24.4s\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "st1 {v21.8b}, [%[output_ptr]], x3\n" + "st1 {v23.8b}, [%[output_ptr]], x3\n" + "b " DEPTHWISECONV_LABEL_HEIGHT_1_END "f\n" + + // Handle bottom right output if exists. 
+ DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER ":\n" + "smlal v21.4s, v0.4h, v9.4h\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "smlal v21.4s, v3.4h, v13.4h\n" + "smlal2 v22.4s, v3.8h, v13.8h\n" + "smlal v21.4s, v4.4h, v14.4h\n" + "smlal2 v22.4s, v4.8h, v14.8h\n" + "smlal v21.4s, v5.4h, v15.4h\n" + "smlal2 v22.4s, v5.8h, v15.8h\n" + "smlal v21.4s, v6.4h, v17.4h\n" + "smlal2 v22.4s, v6.8h, v17.8h\n" + "smlal v21.4s, v7.4h, v18.4h\n" + "smlal2 v22.4s, v7.8h, v18.8h\n" + "smlal v21.4s, v8.4h, v19.4h\n" + "smlal2 v22.4s, v8.8h, v19.8h\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "and v9.16b, v21.16b, v28.16b\n" + "and v12.16b, v22.16b, v28.16b\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v12.4s, v12.4s, #31\n" + "sqadd v21.4s, v21.4s, v9.4s\n" + "sqadd v22.4s, v22.4s, v12.4s\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "sqxtun v21.8b, v21.8h\n" + "st1 {v21.8b}, [%[output_ptr]]\n" + DEPTHWISECONV_LABEL_HEIGHT_1_END ":\n" + : + // Outputs. + [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), + [output_window_height] "+r"(output_window_height) + : + // Inputs. + [bias_ptr] "r"(bias_ptr), [input_row_size] "r"(input_row_size), + [input_depth] "r"(input_depth), + [output_window_width] "r"(output_window_width), + [input_width_increment] "r"(input_width_increment), + [input_height_increment] "r"(input_height_increment), + [output_height_increment] "r"(output_height_increment), + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. 
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", + // We use these general-purpose registers. + "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x9", "x10", "x11", "x12", "x13", "x14", "x15"); +#undef DEPTHWISECONV_LABEL_HEIGHT_2_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_1 +#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_1_END } }; template <> -struct ConvKernel3x3FilterDepth8<4, 4, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - // Reuse 4x2 kernel twice. 
- ConvKernel3x3FilterDepth8<4, 2, 2, 2>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, output_depth, - output_width); - - ConvKernel3x3FilterDepth8<4, 2, 2, 2>::Run( - input_ptr + 4 * input_depth, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr + 2 * output_depth, output_depth, output_width); +struct DepthwiseConvWindow<8, 2, 2> { + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, int64_t input_depth, + int64_t input_row_size, int32 output_window_height, + int32 output_window_width, + const DepthwiseConvParams* params_ptr) { + const int64_t input_width_increment = 4 * input_depth; + const int64_t input_height_increment = 4 * input_row_size; + const int64_t output_height_increment = 2 * params_ptr->output_row_size; + +#define DEPTHWISECONV_LABEL_HEIGHT_2_LOOP "1" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP "2" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "3" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER "4" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP "5" +#define DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP "6" +#define DEPTHWISECONV_LABEL_HEIGHT_1 "7" +#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP "8" +#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "9" +#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER "10" +#define DEPTHWISECONV_LABEL_HEIGHT_1_END "11" + + asm volatile( + // Performs depthwise convolutions for a window specified by + // |output_window_height| and |output_window_width|. The inner-most loop + // processes 2x2 outputs, and any leftovers at the end. + // + // Algorithm works as follows: + // + // 1. Load filters of 8 depth (8x3x3). 
Registers v0--v8 hold filter + // values. + // 2. For 2 output heights at a time: + // i. For 2 output widths at a time at stride 2, a 5x5 input + // window is required. To avoid register exhaustion, we load + // the first 2 rows of the 5x5 input window into registers + // v9--v18, and use the same registers to load the next 2 + // rows, and finally v9--v13 to load the last row. + // Accumulators for all 2x2 outputs are held in registers + // v21--v22 (top left output), v23--v24 (top right output), + // v19--v20 (bottom left output), v25--v26 (bottom right + // output). + // ii. Handle a single leftover width if it exists. + // 3. Handle a single leftover height if it exists. + // i. For 2 output widths at a time at stride 2, load inputs for + // a 1x2 (1 height, 2 width) output window (3x5 input + // window). Registers v9--v23 hold input values. Mul-add with + // accumulators v24--v27. + // ii. Handle a single leftover width if it exists. + // + // Loads are placed as soon as the register is no longer needed and + // interleaved with arithmetic operations to take advantage of + // dual-issue pipelines. We also add input offsets as far from the loads + // as possible to give loads enough cycles to fetch data from memory. + + // Set "constant" registers. These registers may be replaced with temp + // values from time to time when there are not enough NEON registers. + // We use x9--x15 general purpose registers as they are caller-saved + // temporary registers (see http://infocenter.arm.com/help/topic/com.arm.doc.ihi0055b/IHI0055B_aapcs64.pdf).
// NOLINT + "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "ldr w0, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "cmp %w[output_window_height], #2\n" + "dup v28.8h, w0\n" + "neg w9, w9\n" + "ldr w1, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "dup v26.4s, w9\n" + "ldr w2, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w1\n" + "ldr w3, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "dup v29.4s, w2\n" + "ldr w4, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v30.4s, w3\n" + "ldr x5, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + "dup v31.4s, w4\n" + "ldr x19, [%[params_ptr], #" STR(OFFSET_OUTPUT_ROW_SIZE) "]\n" + "ldr w20, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + + // Load filters and add offsets. + "add x10, %[bias_ptr], #16\n" + "ld1 {v0.8b}, [%[filter_ptr]], x5\n" + "dup v9.8h, w20\n" + "ld1 {v1.8b}, [%[filter_ptr]], x5\n" + "uaddw v0.8h, v9.8h, v0.8b\n" + "ld1 {v2.8b}, [%[filter_ptr]], x5\n" + "uaddw v1.8h, v9.8h, v1.8b\n" + "ld1 {v3.8b}, [%[filter_ptr]], x5\n" + "uaddw v2.8h, v9.8h, v2.8b\n" + "ld1 {v4.8b}, [%[filter_ptr]], x5\n" + "uaddw v3.8h, v9.8h, v3.8b\n" + "ld1 {v5.8b}, [%[filter_ptr]], x5\n" + "uaddw v4.8h, v9.8h, v4.8b\n" + "ld1 {v6.8b}, [%[filter_ptr]], x5\n" + "uaddw v5.8h, v9.8h, v5.8b\n" + "ld1 {v7.8b}, [%[filter_ptr]], x5\n" + "uaddw v6.8h, v9.8h, v6.8b\n" + "ld1 {v8.8b}, [%[filter_ptr]]\n" + "uaddw v7.8h, v9.8h, v7.8b\n" + "uaddw v8.8h, v9.8h, v8.8b\n" + + "blt " DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_HEIGHT_2_LOOP ":\n" + // Load the first two rows of the 5x5 input window, then reuse the + // same registers to load subsequent rows as they become available. 
+ "mov x11, %[input_ptr]\n" + "mov x12, x11\n" + "add x13, x12, %[input_row_size]\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "mov w14, %w[output_window_width]\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + // The height 2 / width 2 loop loads an extra 1 output horizontally in + // anticipation for the next iteration. Make sure + // |output_window_width| is large enough to handle the additional + // load, otherwise jump to the appropriate label to handle smaller + // widths. + "cmp w14, #2\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "add x15, x13, %[input_row_size]\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "mov x6, %[output_ptr]\n" + "ld1 {v15.8b}, [x13], %[input_depth]\n" + "add x7, %[output_ptr], x19\n" + "ld1 {v16.8b}, [x13], %[input_depth]\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "ld1 {v22.4s}, [x10]\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "ld1 {v24.4s}, [x10]\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "ld1 {v19.4s}, [%[bias_ptr]]\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "ld1 {v20.4s}, [x10]\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "ld1 {v25.4s}, [%[bias_ptr]]\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "ld1 {v26.4s}, [x10]\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + + "beq " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER "f\n" + "cmp w14, #1\n" + "beq " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP ":\n" + "smlal v21.4s, v0.4h, v9.4h\n" + "ld1 {v12.8b}, [x12], %[input_depth]\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "ld1 {v13.8b}, [x12]\n" + "add x12, x15, %[input_row_size]\n" + "smlal v23.4s, v0.4h, v11.4h\n" + "ld1 {v17.8b}, [x13], %[input_depth]\n" + "smlal2 v24.4s, v0.8h, v11.8h\n" + "ld1 {v18.8b}, [x13]\n" + "add x13, x12, %[input_row_size]\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "ld1 {v9.8b}, [x15], %[input_depth]\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "ld1 {v10.8b}, [x15], %[input_depth]\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + 
"ld1 {v11.8b}, [x15], %[input_depth]\n" + "smlal v21.4s, v3.4h, v14.4h\n" + "smlal2 v22.4s, v3.8h, v14.8h\n" + "ld1 {v14.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v3.4h, v16.4h\n" + "subs w14, w14, #2\n" + "smlal2 v24.4s, v3.8h, v16.8h\n" + "cmp w14, #3\n" + "smlal v21.4s, v4.4h, v15.4h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal2 v22.4s, v4.8h, v15.8h\n" + "ld1 {v15.8b}, [x12], %[input_depth]\n" + "smlal v21.4s, v5.4h, v16.4h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "smlal2 v22.4s, v5.8h, v16.8h\n" + "ld1 {v16.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v1.4h, v12.4h\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + "smlal2 v24.4s, v1.8h, v12.8h\n" + "ld1 {v12.8b}, [x15], %[input_depth]\n" + "smlal v23.4s, v2.4h, v13.4h\n" + "uaddw v18.8h, v28.8h, v18.8b\n" + "smlal2 v24.4s, v2.8h, v13.8h\n" + "ld1 {v13.8b}, [x15]\n" + "smlal v23.4s, v4.4h, v17.4h\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "smlal2 v24.4s, v4.8h, v17.8h\n" + "ld1 {v17.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v5.4h, v18.4h\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "smlal2 v24.4s, v5.8h, v18.8h\n" + "ld1 {v18.8b}, [x12]\n" + + "smlal v21.4s, v6.4h, v9.4h\n" + "smlal2 v22.4s, v6.8h, v9.8h\n" + "smlal v19.4s, v0.4h, v9.4h\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "smlal2 v20.4s, v0.8h, v9.8h\n" + "ld1 {v9.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v6.4h, v11.4h\n" + "smlal2 v24.4s, v6.8h, v11.8h\n" + "smlal v21.4s, v7.4h, v10.4h\n" + "smlal2 v22.4s, v7.8h, v10.8h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal v19.4s, v1.4h, v10.4h\n" + "smlal2 v20.4s, v1.8h, v10.8h\n" + "ld1 {v10.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v7.4h, v12.4h\n" + "smlal2 v24.4s, v7.8h, v12.8h\n" + "smlal v25.4s, v1.4h, v12.4h\n" + "smlal2 v26.4s, v1.8h, v12.8h\n" + "smlal v21.4s, v8.4h, v11.4h\n" + "smlal2 v22.4s, v8.8h, v11.8h\n" + "add x11, x11, %[input_width_increment]\n" + "smlal v19.4s, v2.4h, v11.4h\n" + "mov x12, x11\n" + "smlal2 v20.4s, v2.8h, v11.8h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "smlal v25.4s, v0.4h, 
v11.4h\n" + "smlal2 v26.4s, v0.8h, v11.8h\n" + "ld1 {v11.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v8.4h, v13.4h\n" + "ld1 {v12.8b}, [x13], %[input_depth]\n" + "smlal2 v24.4s, v8.8h, v13.8h\n" + "smlal v25.4s, v2.4h, v13.4h\n" + "smlal2 v26.4s, v2.8h, v13.8h\n" + "ld1 {v13.8b}, [x13]\n" + "add x13, x12, %[input_row_size]\n" + "add x15, x13, %[input_row_size]\n" + + "dup v28.4s, w9\n" + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v27.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v27.4s, v27.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v27.4s\n" + "dup v27.4s, w1\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w3\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w4\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "dup v28.8h, w0\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "ld1 {v22.4s}, [x10]\n" + "sqxtn2 v23.8h, v24.4s\n" + "ld1 {v24.4s}, [x10]\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "st1 {v21.8b}, [x6], x5\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "st1 {v23.8b}, [x6], x5\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + + "smlal v19.4s, 
v6.4h, v9.4h\n" + "smlal2 v20.4s, v6.8h, v9.8h\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "smlal v25.4s, v6.4h, v11.4h\n" + "smlal2 v26.4s, v6.8h, v11.8h\n" + "smlal v19.4s, v7.4h, v10.4h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal2 v20.4s, v7.8h, v10.8h\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "smlal v25.4s, v7.4h, v12.4h\n" + "smlal2 v26.4s, v7.8h, v12.8h\n" + "smlal v19.4s, v8.4h, v11.4h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "smlal2 v20.4s, v8.8h, v11.8h\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "smlal v25.4s, v8.4h, v13.4h\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "smlal2 v26.4s, v8.8h, v13.8h\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + "smlal v19.4s, v3.4h, v14.4h\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "smlal2 v20.4s, v3.8h, v14.8h\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "smlal v25.4s, v3.4h, v16.4h\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "smlal2 v26.4s, v3.8h, v16.8h\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + "smlal v19.4s, v4.4h, v15.4h\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + "smlal2 v20.4s, v4.8h, v15.8h\n" + "ld1 {v15.8b}, [x13], %[input_depth]\n" + "smlal v25.4s, v4.4h, v17.4h\n" + "smlal2 v26.4s, v4.8h, v17.8h\n" + "smlal v19.4s, v5.4h, v16.4h\n" + "uaddw v18.8h, v28.8h, v18.8b\n" + "smlal2 v20.4s, v5.8h, v16.8h\n" + "ld1 {v16.8b}, [x13], %[input_depth]\n" + "smlal v25.4s, v5.4h, v18.4h\n" + "smlal2 v26.4s, v5.8h, v18.8h\n" + + "dup v28.4s, w9\n" + "sqrdmulh v19.4s, v19.4s, v27.4s\n" + "sqrdmulh v20.4s, v20.4s, v27.4s\n" + "sqrdmulh v25.4s, v25.4s, v27.4s\n" + "sqrdmulh v26.4s, v26.4s, v27.4s\n" + "and v27.16b, v19.16b, v28.16b\n" + "and v29.16b, v20.16b, v28.16b\n" + "and v30.16b, v25.16b, v28.16b\n" + "and v31.16b, v26.16b, v28.16b\n" + "sshr v27.4s, v27.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v19.4s, v19.4s, v27.4s\n" + "dup v27.4s, w1\n" + "sqadd v20.4s, v20.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v25.4s, v25.4s, v30.4s\n" + "dup v30.4s, w3\n" + "sqadd 
v26.4s, v26.4s, v31.4s\n" + "dup v31.4s, w4\n" + "srshl v19.4s, v19.4s, v28.4s\n" + "srshl v20.4s, v20.4s, v28.4s\n" + "srshl v25.4s, v25.4s, v28.4s\n" + "srshl v26.4s, v26.4s, v28.4s\n" + "dup v28.8h, w0\n" + "add v19.4s, v19.4s, v29.4s\n" + "add v20.4s, v20.4s, v29.4s\n" + "add v25.4s, v25.4s, v29.4s\n" + "add v26.4s, v26.4s, v29.4s\n" + "smax v19.4s, v19.4s, v30.4s\n" + "smax v20.4s, v20.4s, v30.4s\n" + "smax v25.4s, v25.4s, v30.4s\n" + "smax v26.4s, v26.4s, v30.4s\n" + "smin v19.4s, v19.4s, v31.4s\n" + "smin v20.4s, v20.4s, v31.4s\n" + "smin v25.4s, v25.4s, v31.4s\n" + "smin v26.4s, v26.4s, v31.4s\n" + "sqxtn v19.4h, v19.4s\n" + "sqxtn v25.4h, v25.4s\n" + "sqxtn2 v19.8h, v20.4s\n" + "ld1 {v20.4s}, [x10]\n" + "sqxtn2 v25.8h, v26.4s\n" + "ld1 {v26.4s}, [x10]\n" + "sqxtun v19.8b, v19.8h\n" + "sqxtun v25.8b, v25.8h\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "st1 {v19.8b}, [x7], x5\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "st1 {v25.8b}, [x7], x5\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "ld1 {v19.4s}, [%[bias_ptr]]\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "ld1 {v25.4s}, [%[bias_ptr]]\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + + "bge " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP "b\n" + + // At this point, there will be one of 2 width or 1 width leftover, + // not both. + "cmp w14, #2\n" + "blt " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "f\n" + + // Handle last 2 columns if exists. 
+ DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER ":\n" + "smlal v21.4s, v0.4h, v9.4h\n" + "ld1 {v12.8b}, [x12], %[input_depth]\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "ld1 {v13.8b}, [x12]\n" + "add x12, x15, %[input_row_size]\n" + "smlal v23.4s, v0.4h, v11.4h\n" + "ld1 {v17.8b}, [x13], %[input_depth]\n" + "smlal2 v24.4s, v0.8h, v11.8h\n" + "ld1 {v18.8b}, [x13]\n" + "add x13, x12, %[input_row_size]\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "ld1 {v9.8b}, [x15], %[input_depth]\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "ld1 {v10.8b}, [x15], %[input_depth]\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "ld1 {v11.8b}, [x15], %[input_depth]\n" + "smlal v21.4s, v3.4h, v14.4h\n" + "smlal2 v22.4s, v3.8h, v14.8h\n" + "ld1 {v14.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v3.4h, v16.4h\n" + "smlal2 v24.4s, v3.8h, v16.8h\n" + "smlal v21.4s, v4.4h, v15.4h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal2 v22.4s, v4.8h, v15.8h\n" + "ld1 {v15.8b}, [x12], %[input_depth]\n" + "smlal v21.4s, v5.4h, v16.4h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "smlal2 v22.4s, v5.8h, v16.8h\n" + "ld1 {v16.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v1.4h, v12.4h\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + "smlal2 v24.4s, v1.8h, v12.8h\n" + "ld1 {v12.8b}, [x15], %[input_depth]\n" + "smlal v23.4s, v2.4h, v13.4h\n" + "uaddw v18.8h, v28.8h, v18.8b\n" + "smlal2 v24.4s, v2.8h, v13.8h\n" + "ld1 {v13.8b}, [x15]\n" + "smlal v23.4s, v4.4h, v17.4h\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "smlal2 v24.4s, v4.8h, v17.8h\n" + "ld1 {v17.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v5.4h, v18.4h\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "smlal2 v24.4s, v5.8h, v18.8h\n" + "ld1 {v18.8b}, [x12]\n" + + "smlal v21.4s, v6.4h, v9.4h\n" + "smlal2 v22.4s, v6.8h, v9.8h\n" + "smlal v19.4s, v0.4h, v9.4h\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "smlal2 v20.4s, v0.8h, v9.8h\n" + "ld1 {v9.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v6.4h, v11.4h\n" + "smlal2 v24.4s, v6.8h, v11.8h\n" + "smlal v21.4s, v7.4h, 
v10.4h\n" + "smlal2 v22.4s, v7.8h, v10.8h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal v19.4s, v1.4h, v10.4h\n" + "smlal2 v20.4s, v1.8h, v10.8h\n" + "ld1 {v10.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v7.4h, v12.4h\n" + "smlal2 v24.4s, v7.8h, v12.8h\n" + "smlal v25.4s, v1.4h, v12.4h\n" + "smlal2 v26.4s, v1.8h, v12.8h\n" + "smlal v21.4s, v8.4h, v11.4h\n" + "smlal2 v22.4s, v8.8h, v11.8h\n" + "smlal v19.4s, v2.4h, v11.4h\n" + "smlal2 v20.4s, v2.8h, v11.8h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "smlal v25.4s, v0.4h, v11.4h\n" + "smlal2 v26.4s, v0.8h, v11.8h\n" + "ld1 {v11.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v8.4h, v13.4h\n" + "ld1 {v12.8b}, [x13], %[input_depth]\n" + "smlal2 v24.4s, v8.8h, v13.8h\n" + "smlal v25.4s, v2.4h, v13.4h\n" + "smlal2 v26.4s, v2.8h, v13.8h\n" + "ld1 {v13.8b}, [x13]\n" + + "dup v28.4s, w9\n" + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v27.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v27.4s, v27.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v27.4s\n" + "dup v27.4s, w1\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w3\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w4\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "dup v28.8h, w0\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, 
v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "ld1 {v22.4s}, [x10]\n" + "sqxtn2 v23.8h, v24.4s\n" + "ld1 {v24.4s}, [x10]\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "st1 {v21.8b}, [x6], x5\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "st1 {v23.8b}, [x6]\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + + "smlal v19.4s, v6.4h, v9.4h\n" + "smlal2 v20.4s, v6.8h, v9.8h\n" + "smlal v25.4s, v6.4h, v11.4h\n" + "smlal2 v26.4s, v6.8h, v11.8h\n" + "smlal v19.4s, v7.4h, v10.4h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal2 v20.4s, v7.8h, v10.8h\n" + "smlal v25.4s, v7.4h, v12.4h\n" + "smlal2 v26.4s, v7.8h, v12.8h\n" + "smlal v19.4s, v8.4h, v11.4h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "smlal2 v20.4s, v8.8h, v11.8h\n" + "smlal v25.4s, v8.4h, v13.4h\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "smlal2 v26.4s, v8.8h, v13.8h\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + "smlal v19.4s, v3.4h, v14.4h\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "smlal2 v20.4s, v3.8h, v14.8h\n" + "smlal v25.4s, v3.4h, v16.4h\n" + "smlal2 v26.4s, v3.8h, v16.8h\n" + "smlal v19.4s, v4.4h, v15.4h\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + "smlal2 v20.4s, v4.8h, v15.8h\n" + "smlal v25.4s, v4.4h, v17.4h\n" + "smlal2 v26.4s, v4.8h, v17.8h\n" + "smlal v19.4s, v5.4h, v16.4h\n" + "uaddw v18.8h, v28.8h, v18.8b\n" + "smlal2 v20.4s, v5.8h, v16.8h\n" + "smlal v25.4s, v5.4h, v18.4h\n" + "smlal2 v26.4s, v5.8h, v18.8h\n" + + "dup v28.4s, w9\n" + "sqrdmulh v19.4s, v19.4s, v27.4s\n" + "sqrdmulh v20.4s, v20.4s, v27.4s\n" + "sqrdmulh v25.4s, v25.4s, v27.4s\n" + "sqrdmulh v26.4s, v26.4s, v27.4s\n" + "and v27.16b, v19.16b, v28.16b\n" + "and v29.16b, v20.16b, v28.16b\n" + "and v30.16b, v25.16b, v28.16b\n" + "and v31.16b, v26.16b, v28.16b\n" + "sshr v27.4s, v27.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v19.4s, v19.4s, v27.4s\n" + "dup v27.4s, w1\n" 
+ "sqadd v20.4s, v20.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v25.4s, v25.4s, v30.4s\n" + "dup v30.4s, w3\n" + "sqadd v26.4s, v26.4s, v31.4s\n" + "dup v31.4s, w4\n" + "srshl v19.4s, v19.4s, v28.4s\n" + "srshl v20.4s, v20.4s, v28.4s\n" + "srshl v25.4s, v25.4s, v28.4s\n" + "srshl v26.4s, v26.4s, v28.4s\n" + "dup v28.8h, w0\n" + "add v19.4s, v19.4s, v29.4s\n" + "add v20.4s, v20.4s, v29.4s\n" + "add v25.4s, v25.4s, v29.4s\n" + "add v26.4s, v26.4s, v29.4s\n" + "smax v19.4s, v19.4s, v30.4s\n" + "smax v20.4s, v20.4s, v30.4s\n" + "smax v25.4s, v25.4s, v30.4s\n" + "smax v26.4s, v26.4s, v30.4s\n" + "smin v19.4s, v19.4s, v31.4s\n" + "smin v20.4s, v20.4s, v31.4s\n" + "smin v25.4s, v25.4s, v31.4s\n" + "smin v26.4s, v26.4s, v31.4s\n" + "sqxtn v19.4h, v19.4s\n" + "sqxtn v25.4h, v25.4s\n" + "sqxtn2 v19.8h, v20.4s\n" + "sqxtn2 v25.8h, v26.4s\n" + "sqxtun v19.8b, v19.8h\n" + "sqxtun v25.8b, v25.8h\n" + "st1 {v19.8b}, [x7], x5\n" + "st1 {v25.8b}, [x7]\n" + "b " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP "f\n" + + // Handle last column if exists. + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER ":\n" + // Registers v9, v10, v11, v14, v15, and v16 have already been loaded + // with the correct values at this point. This corresponds to the + // first two input rows of the top left output. Now load the last + // input row for this output. Once these inputs are no longer needed, + // load the input rows for the bottom left output. 
+ "add x12, x15, %[input_row_size]\n" + "add x13, x12, %[input_row_size]\n" + + "ld1 {v12.8b}, [x15], %[input_depth]\n" + "smlal v21.4s, v0.4h, v9.4h\n" + "ld1 {v13.8b}, [x15], %[input_depth]\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "ld1 {v17.8b}, [x15]\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "ld1 {v11.8b}, [x12]\n" + "smlal v21.4s, v3.4h, v14.4h\n" + "smlal2 v22.4s, v3.8h, v14.8h\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "smlal v21.4s, v4.4h, v15.4h\n" + "smlal2 v22.4s, v4.8h, v15.8h\n" + "ld1 {v15.8b}, [x13], %[input_depth]\n" + "smlal v21.4s, v5.4h, v16.4h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal2 v22.4s, v5.8h, v16.8h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "ld1 {v16.8b}, [x13]\n" + + "smlal v21.4s, v6.4h, v12.4h\n" + "smlal2 v22.4s, v6.8h, v12.8h\n" + "smlal v23.4s, v0.4h, v12.4h\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + "smlal2 v24.4s, v0.8h, v12.8h\n" + "smlal v21.4s, v7.4h, v13.4h\n" + "smlal2 v22.4s, v7.8h, v13.8h\n" + "smlal v23.4s, v1.4h, v13.4h\n" + "smlal2 v24.4s, v1.8h, v13.8h\n" + "smlal v21.4s, v8.4h, v17.4h\n" + "smlal2 v22.4s, v8.8h, v17.8h\n" + "smlal v23.4s, v2.4h, v17.4h\n" + "smlal2 v24.4s, v2.8h, v17.8h\n" + + "dup v26.4s, w9\n" + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "and v18.16b, v21.16b, v26.16b\n" + "and v19.16b, v22.16b, v26.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v21.4s, v21.4s, v18.4s\n" + "sqadd v22.4s, v22.4s, v19.4s\n" + "srshl v21.4s, v21.4s, v26.4s\n" + "srshl v22.4s, v22.4s, v26.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "sqxtun v21.8b, v21.8h\n" + 
"uaddw v9.8h, v28.8h, v9.8b\n" + "st1 {v21.8b}, [x6]\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + + "smlal v23.4s, v3.4h, v9.4h\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "smlal2 v24.4s, v3.8h, v9.8h\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "smlal v23.4s, v4.4h, v10.4h\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "smlal2 v24.4s, v4.8h, v10.8h\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + "smlal v23.4s, v5.4h, v11.4h\n" + "smlal2 v24.4s, v5.8h, v11.8h\n" + + "smlal v23.4s, v6.4h, v14.4h\n" + "smlal2 v24.4s, v6.8h, v14.8h\n" + "smlal v23.4s, v7.4h, v15.4h\n" + "smlal2 v24.4s, v7.8h, v15.8h\n" + "smlal v23.4s, v8.4h, v16.4h\n" + "smlal2 v24.4s, v8.8h, v16.8h\n" + + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v18.16b, v23.16b, v26.16b\n" + "and v19.16b, v24.16b, v26.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v23.4s, v23.4s, v18.4s\n" + "sqadd v24.4s, v24.4s, v19.4s\n" + "srshl v23.4s, v23.4s, v26.4s\n" + "srshl v24.4s, v24.4s, v26.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v23.8h, v24.4s\n" + "sqxtun v23.8b, v23.8h\n" + "st1 {v23.8b}, [x7]\n" + + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP ":\n" + "subs %w[output_window_height], %w[output_window_height], #2\n" + "add %[input_ptr], %[input_ptr], %[input_height_increment]\n" + "cmp %w[output_window_height], #2\n" + "add %[output_ptr], %[output_ptr], %[output_height_increment]\n" + "bge " DEPTHWISECONV_LABEL_HEIGHT_2_LOOP "b\n" + + DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP ":\n" + "cmp %w[output_window_height], #1\n" + "blt " DEPTHWISECONV_LABEL_HEIGHT_1_END "f\n" + + DEPTHWISECONV_LABEL_HEIGHT_1 ":\n" + "mov x11, %[input_ptr]\n" + "mov x12, x11\n" + "add x13, x12, %[input_row_size]\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "add x15, x13, %[input_row_size]\n" + "ld1 
{v10.8b}, [x12], %[input_depth]\n" + "mov x6, %[output_ptr]\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "mov w14, %w[output_window_width]\n" + // The height 1 / width 2 loop loads an extra 1x1 output in anticipation + // for the next iteration. Make sure |output_window_width| is large + // enough to handle the additional load, otherwise jump to the + // appropriate label to handle smaller widths. + "cmp w14, #2\n" + "ld1 {v12.8b}, [x13], %[input_depth]\n" + "ld1 {v13.8b}, [x13], %[input_depth]\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "ld1 {v15.8b}, [x15], %[input_depth]\n" + "ld1 {v16.8b}, [x15], %[input_depth]\n" + "ld1 {v17.8b}, [x15], %[input_depth]\n" + + "uaddw v9.8h, v28.8h, v9.8b\n" + "ld1 {v24.4s}, [%[bias_ptr]]\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "ld1 {v25.4s}, [x10]\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "ld1 {v26.4s}, [%[bias_ptr]]\n" + "ld1 {v27.4s}, [x10]\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + + "beq " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER "f\n" + "cmp w14, #1\n" + "beq " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP ":\n" + "smlal v24.4s, v0.4h, v9.4h\n" + "ld1 {v18.8b}, [x12], %[input_depth]\n" + "smlal2 v25.4s, v0.8h, v9.8h\n" + "ld1 {v19.8b}, [x12]\n" + "smlal v26.4s, v0.4h, v11.4h\n" + "ld1 {v20.8b}, [x13], %[input_depth]\n" + "smlal2 v27.4s, v0.8h, v11.8h\n" + "ld1 {v21.8b}, [x13]\n" + "smlal v24.4s, v1.4h, v10.4h\n" + "ld1 {v22.8b}, [x15], %[input_depth]\n" + "smlal2 v25.4s, v1.8h, v10.8h\n" + "ld1 {v23.8b}, [x15]\n" + "smlal v24.4s, v2.4h, v11.4h\n" + "subs w14, w14, #2\n" + "smlal2 v25.4s, v2.8h, v11.8h\n" + "cmp w14, #3\n" + "smlal v24.4s, v3.4h, v12.4h\n" + "add x11, x11, %[input_width_increment]\n" + "smlal2 v25.4s, v3.8h, v12.8h\n" + "mov x12, x11\n" + "smlal v26.4s, v3.4h, 
v14.4h\n" + "add x13, x12, %[input_row_size]\n" + "smlal2 v27.4s, v3.8h, v14.8h\n" + "add x15, x13, %[input_row_size]\n" + "smlal v24.4s, v4.4h, v13.4h\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "smlal2 v25.4s, v4.8h, v13.8h\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "smlal v24.4s, v5.4h, v14.4h\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "smlal2 v25.4s, v5.8h, v14.8h\n" + "ld1 {v12.8b}, [x13], %[input_depth]\n" + "smlal v24.4s, v6.4h, v15.4h\n" + "ld1 {v13.8b}, [x13], %[input_depth]\n" + "smlal2 v25.4s, v6.8h, v15.8h\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "smlal v26.4s, v6.4h, v17.4h\n" + "ld1 {v15.8b}, [x15], %[input_depth]\n" + "smlal2 v27.4s, v6.8h, v17.8h\n" + "smlal v24.4s, v7.4h, v16.4h\n" + "smlal2 v25.4s, v7.8h, v16.8h\n" + "ld1 {v16.8b}, [x15], %[input_depth]\n" + "smlal v24.4s, v8.4h, v17.4h\n" + "uaddw v18.8h, v28.8h, v18.8b\n" + "smlal2 v25.4s, v8.8h, v17.8h\n" + "ld1 {v17.8b}, [x15], %[input_depth]\n" + "uaddw v19.8h, v28.8h, v19.8b\n" + + "smlal v26.4s, v1.4h, v18.4h\n" + "uaddw v20.8h, v28.8h, v20.8b\n" + "smlal2 v27.4s, v1.8h, v18.8h\n" + "smlal v26.4s, v2.4h, v19.4h\n" + "uaddw v21.8h, v28.8h, v21.8b\n" + "smlal2 v27.4s, v2.8h, v19.8h\n" + "smlal v26.4s, v4.4h, v20.4h\n" + "smlal v26.4s, v5.4h, v21.4h\n" + "smlal2 v27.4s, v4.8h, v20.8h\n" + "uaddw v22.8h, v28.8h, v22.8b\n" + "smlal2 v27.4s, v5.8h, v21.8h\n" + "uaddw v23.8h, v28.8h, v23.8b\n" + "smlal v26.4s, v7.4h, v22.4h\n" + "smlal2 v27.4s, v7.8h, v22.8h\n" + "smlal v26.4s, v8.4h, v23.4h\n" + "smlal2 v27.4s, v8.8h, v23.8h\n" + + "dup v28.4s, w1\n" + "dup v29.4s, w9\n" + "sqrdmulh v24.4s, v24.4s, v28.4s\n" + "sqrdmulh v25.4s, v25.4s, v28.4s\n" + "sqrdmulh v26.4s, v26.4s, v28.4s\n" + "sqrdmulh v27.4s, v27.4s, v28.4s\n" + "dup v28.4s, w2\n" + "and v30.16b, v24.16b, v29.16b\n" + "and v31.16b, v25.16b, v29.16b\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v24.4s, v24.4s, v30.4s\n" + "sqadd v25.4s, v25.4s, v31.4s\n" + "and v30.16b, v26.16b, v29.16b\n" + 
"and v31.16b, v27.16b, v29.16b\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v26.4s, v26.4s, v30.4s\n" + "dup v30.4s, w3\n" + "sqadd v27.4s, v27.4s, v31.4s\n" + "dup v31.4s, w4\n" + "srshl v24.4s, v24.4s, v29.4s\n" + "srshl v25.4s, v25.4s, v29.4s\n" + "srshl v26.4s, v26.4s, v29.4s\n" + "srshl v27.4s, v27.4s, v29.4s\n" + "add v24.4s, v24.4s, v28.4s\n" + "add v25.4s, v25.4s, v28.4s\n" + "add v26.4s, v26.4s, v28.4s\n" + "add v27.4s, v27.4s, v28.4s\n" + "dup v28.8h, w0\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smax v25.4s, v25.4s, v30.4s\n" + "smax v26.4s, v26.4s, v30.4s\n" + "smax v27.4s, v27.4s, v30.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "smin v25.4s, v25.4s, v31.4s\n" + "smin v26.4s, v26.4s, v31.4s\n" + "smin v27.4s, v27.4s, v31.4s\n" + "sqxtn v24.4h, v24.4s\n" + "sqxtn v26.4h, v26.4s\n" + "sqxtn2 v24.8h, v25.4s\n" + "ld1 {v25.4s}, [x10]\n" + "sqxtn2 v26.8h, v27.4s\n" + "ld1 {v27.4s}, [x10]\n" + "sqxtun v24.8b, v24.8h\n" + "sqxtun v26.8b, v26.8h\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "st1 {v24.8b}, [x6], x5\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "st1 {v26.8b}, [x6], x5\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "ld1 {v24.4s}, [%[bias_ptr]]\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "ld1 {v26.4s}, [%[bias_ptr]]\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + + "bge " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP "b\n" + + // At this point, there will be one of 2 width or 1 width leftover, + // not both. + "cmp w14, #2\n" + "blt " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "f\n" + + // Handle last two horizontal outputs if exists. 
+ DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER ":\n" + "smlal v24.4s, v0.4h, v9.4h\n" + "ld1 {v18.8b}, [x12], %[input_depth]\n" + "smlal2 v25.4s, v0.8h, v9.8h\n" + "ld1 {v19.8b}, [x12]\n" + "smlal v26.4s, v0.4h, v11.4h\n" + "ld1 {v20.8b}, [x13], %[input_depth]\n" + "smlal2 v27.4s, v0.8h, v11.8h\n" + "ld1 {v21.8b}, [x13]\n" + "smlal v24.4s, v1.4h, v10.4h\n" + "ld1 {v22.8b}, [x15], %[input_depth]\n" + "smlal2 v25.4s, v1.8h, v10.8h\n" + "ld1 {v23.8b}, [x15]\n" + "smlal v24.4s, v2.4h, v11.4h\n" + "smlal2 v25.4s, v2.8h, v11.8h\n" + "smlal v24.4s, v3.4h, v12.4h\n" + "smlal2 v25.4s, v3.8h, v12.8h\n" + "smlal v26.4s, v3.4h, v14.4h\n" + "smlal2 v27.4s, v3.8h, v14.8h\n" + "smlal v24.4s, v4.4h, v13.4h\n" + "smlal2 v25.4s, v4.8h, v13.8h\n" + "smlal v24.4s, v5.4h, v14.4h\n" + "smlal2 v25.4s, v5.8h, v14.8h\n" + "smlal v24.4s, v6.4h, v15.4h\n" + "smlal2 v25.4s, v6.8h, v15.8h\n" + "smlal v26.4s, v6.4h, v17.4h\n" + "smlal2 v27.4s, v6.8h, v17.8h\n" + "smlal v24.4s, v7.4h, v16.4h\n" + "smlal2 v25.4s, v7.8h, v16.8h\n" + "smlal v24.4s, v8.4h, v17.4h\n" + "uaddw v18.8h, v28.8h, v18.8b\n" + "smlal2 v25.4s, v8.8h, v17.8h\n" + "uaddw v19.8h, v28.8h, v19.8b\n" + + "smlal v26.4s, v1.4h, v18.4h\n" + "uaddw v20.8h, v28.8h, v20.8b\n" + "smlal2 v27.4s, v1.8h, v18.8h\n" + "smlal v26.4s, v2.4h, v19.4h\n" + "uaddw v21.8h, v28.8h, v21.8b\n" + "smlal2 v27.4s, v2.8h, v19.8h\n" + "smlal v26.4s, v4.4h, v20.4h\n" + "smlal v26.4s, v5.4h, v21.4h\n" + "smlal2 v27.4s, v4.8h, v20.8h\n" + "uaddw v22.8h, v28.8h, v22.8b\n" + "smlal2 v27.4s, v5.8h, v21.8h\n" + "uaddw v23.8h, v28.8h, v23.8b\n" + "smlal v26.4s, v7.4h, v22.4h\n" + "smlal2 v27.4s, v7.8h, v22.8h\n" + "smlal v26.4s, v8.4h, v23.4h\n" + "smlal2 v27.4s, v8.8h, v23.8h\n" + + "dup v28.4s, w1\n" + "dup v29.4s, w9\n" + "sqrdmulh v24.4s, v24.4s, v28.4s\n" + "sqrdmulh v25.4s, v25.4s, v28.4s\n" + "sqrdmulh v26.4s, v26.4s, v28.4s\n" + "sqrdmulh v27.4s, v27.4s, v28.4s\n" + "dup v28.4s, w2\n" + "and v30.16b, v24.16b, v29.16b\n" + "and v31.16b, v25.16b, v29.16b\n" 
+ "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v24.4s, v24.4s, v30.4s\n" + "sqadd v25.4s, v25.4s, v31.4s\n" + "and v30.16b, v26.16b, v29.16b\n" + "and v31.16b, v27.16b, v29.16b\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v26.4s, v26.4s, v30.4s\n" + "dup v30.4s, w3\n" + "sqadd v27.4s, v27.4s, v31.4s\n" + "dup v31.4s, w4\n" + "srshl v24.4s, v24.4s, v29.4s\n" + "srshl v25.4s, v25.4s, v29.4s\n" + "srshl v26.4s, v26.4s, v29.4s\n" + "srshl v27.4s, v27.4s, v29.4s\n" + "add v24.4s, v24.4s, v28.4s\n" + "add v25.4s, v25.4s, v28.4s\n" + "add v26.4s, v26.4s, v28.4s\n" + "add v27.4s, v27.4s, v28.4s\n" + "dup v28.8h, w0\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smax v25.4s, v25.4s, v30.4s\n" + "smax v26.4s, v26.4s, v30.4s\n" + "smax v27.4s, v27.4s, v30.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "smin v25.4s, v25.4s, v31.4s\n" + "smin v26.4s, v26.4s, v31.4s\n" + "smin v27.4s, v27.4s, v31.4s\n" + "sqxtn v24.4h, v24.4s\n" + "sqxtn v26.4h, v26.4s\n" + "sqxtn2 v24.8h, v25.4s\n" + "sqxtn2 v26.8h, v27.4s\n" + "sqxtun v24.8b, v24.8h\n" + "sqxtun v26.8b, v26.8h\n" + "st1 {v24.8b}, [x6], x5\n" + "st1 {v26.8b}, [x6]\n" + "b " DEPTHWISECONV_LABEL_HEIGHT_1_END "f\n" + + // Handle bottom right output if exists. 
+ DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER ":\n" + "dup v26.4s, w9\n" + "dup v27.4s, w1\n" + "dup v29.4s, w2\n" + + "smlal v24.4s, v0.4h, v9.4h\n" + "smlal2 v25.4s, v0.8h, v9.8h\n" + "smlal v24.4s, v1.4h, v10.4h\n" + "smlal2 v25.4s, v1.8h, v10.8h\n" + "smlal v24.4s, v2.4h, v11.4h\n" + "smlal2 v25.4s, v2.8h, v11.8h\n" + "smlal v24.4s, v3.4h, v12.4h\n" + "smlal2 v25.4s, v3.8h, v12.8h\n" + "smlal v24.4s, v4.4h, v13.4h\n" + "smlal2 v25.4s, v4.8h, v13.8h\n" + "smlal v24.4s, v5.4h, v14.4h\n" + "smlal2 v25.4s, v5.8h, v14.8h\n" + "smlal v24.4s, v6.4h, v15.4h\n" + "smlal2 v25.4s, v6.8h, v15.8h\n" + "smlal v24.4s, v7.4h, v16.4h\n" + "smlal2 v25.4s, v7.8h, v16.8h\n" + "smlal v24.4s, v8.4h, v17.4h\n" + "smlal2 v25.4s, v8.8h, v17.8h\n" + + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "sqrdmulh v25.4s, v25.4s, v27.4s\n" + "and v18.16b, v24.16b, v26.16b\n" + "and v19.16b, v25.16b, v26.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v24.4s, v24.4s, v18.4s\n" + "sqadd v25.4s, v25.4s, v19.4s\n" + "srshl v24.4s, v24.4s, v26.4s\n" + "srshl v25.4s, v25.4s, v26.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "add v25.4s, v25.4s, v29.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smax v25.4s, v25.4s, v30.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "smin v25.4s, v25.4s, v31.4s\n" + "sqxtn v24.4h, v24.4s\n" + "sqxtn2 v24.8h, v25.4s\n" + "sqxtun v24.8b, v24.8h\n" + "st1 {v24.8b}, [x6]\n" + + DEPTHWISECONV_LABEL_HEIGHT_1_END ":\n" + : + // Outputs. + [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), + [output_window_height] "+r"(output_window_height) + : + // Inputs. + [bias_ptr] "r"(bias_ptr), [input_row_size] "r"(input_row_size), + [input_depth] "r"(input_depth), + [output_window_width] "r"(output_window_width), + [input_width_increment] "r"(input_width_increment), + [input_height_increment] "r"(input_height_increment), + [output_height_increment] "r"(output_height_increment), + [params_ptr] "r"(params_ptr) + : + // Clobbers. 
+ "cc", "memory", + // We use these NEON registers. + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", + // We use these general-purpose registers. + "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x9", "x10", "x11", "x12", "x13", "x14", "x15", + "x19", "x20"); +#undef DEPTHWISECONV_LABEL_HEIGHT_2_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_1 +#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_1_END } }; -template <> -struct ConvKernel3x3FilterDepth8<4, 1, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - const int output_row_size = output_depth * output_width; - - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8; - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, - temp_8; - - const uint8* ptr = input_ptr; - - // Load all inputs for top output. 
- temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - temp_8 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Second output. 
- output_ptr += output_row_size; - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - - DotProductAndStore( - filter, input_6, input_7, input_8, input_0, input_1, input_2, input_3, - input_4, input_5, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Third output. 
- output_ptr += output_row_size; - - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - temp_8 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - - DotProductAndStore( - filter, input_3, input_4, input_5, input_6, input_7, input_8, input_0, - input_1, input_2, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Fourth output. 
- output_ptr += output_row_size; - - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - temp_8 = vld1_u8(ptr + 2 * input_depth); - - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<2, 2, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - Int32x8 acc_0, acc_1, acc_2, acc_3; - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_2.low = vld1q_s32(bias_ptr); - acc_3.low = vld1q_s32(bias_ptr); - - bias_ptr += 4; - acc_0.high = vld1q_s32(bias_ptr); - acc_1.high = vld1q_s32(bias_ptr); - acc_2.high = vld1q_s32(bias_ptr); - acc_3.high = 
vld1q_s32(bias_ptr); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - - // Add scope for input registers to help the compiler know that it is - // not needed. - { - // To process 2x2 outputs using a 3x3 filter at stride 2, we require - // 5x5 inputs. We load the first 5x2 inputs at a time. - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9; - - const uint8* ptr = input_ptr; - - // Load inputs. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4; - - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - acc_0 = 
MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2, - input_2, input_3, input_4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5, - input_5, input_6, input_7); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5, - input_7, input_8, input_9); - - // Load next inputs. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4; - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - acc_0 = 
MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8, - input_2, input_3, input_4); - - // Moving onto the two bottom outputs. - acc_2 = MultiplyAccumulateRow(acc_2, filter.f0, filter.f1, filter.f2, - input_0, input_1, input_2); - - acc_3 = MultiplyAccumulateRow(acc_3, filter.f0, filter.f1, filter.f2, - input_2, input_3, input_4); - - acc_2 = MultiplyAccumulateRow(acc_2, filter.f3, filter.f4, filter.f5, - input_5, input_6, input_7); - - acc_3 = MultiplyAccumulateRow(acc_3, filter.f3, filter.f4, filter.f5, - input_7, input_8, input_9); - - // Load last input row. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4; +enum class EdgeType { kCorner, kHorizontal, kVertical, kCenter }; - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - } - - acc_2 = MultiplyAccumulateRow(acc_2, filter.f6, filter.f7, filter.f8, - input_0, input_1, input_2); - - acc_3 = MultiplyAccumulateRow(acc_3, filter.f6, filter.f7, filter.f8, - input_2, input_3, input_4); - } - - DownquantizeAndStore2x2Output(acc_0, acc_1, acc_2, acc_3, output_offset, - output_multiplier, output_shift, - output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - } -}; +template +struct 
DepthwiseConvPartial {}; template <> -struct ConvKernel3x3FilterDepth8<2, 4, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - // Reuse 2x2 kernel twice. - ConvKernel3x3FilterDepth8<2, 2, 2, 2>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, output_depth, - output_width); - - ConvKernel3x3FilterDepth8<2, 2, 2, 2>::Run( - input_ptr + 4 * input_depth, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr + 2 * output_depth, output_depth, output_width); +struct DepthwiseConvPartial { + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, + const DepthwiseConvParams* params_ptr) { +#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1" +#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2" + asm volatile( + // Performs depthwise convolutions for an input window of size 1x1 and + // padding of 1 across the full depth. Expects |input_ptr| and + // |filter_ptr| to be pointing to the 1x1 input and filter values. 
+ "ld1 {v8.8b}, [%[input_ptr]], #8\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "ldr x11, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "dup v26.8h, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w10\n" + "ld1 {v0.8b}, [%[filter_ptr]], #8\n" + "cmp x11, #16\n" + "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "dup v28.4s, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "neg w10, w10\n" + "dup v29.4s, w10\n" + "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v30.4s, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + "dup v31.4s, w10\n" + "dup v25.8h, w9\n" + + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v0.8h, v25.8h, v0.8b\n" + + "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "subs x11, x11, #8\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "ld1 {v8.8b}, [%[input_ptr]], #8\n" + "cmp x11, #16\n" + "ld1 {v0.8b}, [%[filter_ptr]], #8\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]], #8\n" + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v0.8h, 
v25.8h, v0.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + + "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n" + + DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]]\n" + : + // Outputs. + [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr) + : + // Inputs. + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. + "v0", "v8", "v16", "v17", "v18", "v19", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", + // We use these general-purpose registers. 
+ "x9", "x10", "x11"); +#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP +#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP } }; template <> -struct ConvKernel3x3FilterDepth8<2, 1, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - const int output_row_size = output_depth * output_width; - - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8; - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, - temp_8; - - const uint8* ptr = input_ptr; - - // Load all inputs for top output. 
- temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - temp_8 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Second output. 
- output_ptr += output_row_size; - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - - DotProductAndStore( - filter, input_6, input_7, input_8, input_0, input_1, input_2, input_3, - input_4, input_5, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); +struct DepthwiseConvPartial { + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, + const DepthwiseConvParams* params_ptr) { +#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1" +#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2" + asm volatile( + // Performs depthwise convolutions for an input window of size 2x2 and + // padding of 1 across the full depth. Expects |input_ptr| and + // |filter_ptr| to be pointing to the beginning of the 2x2 input and + // filter values. + + // Load input and filter values. 
+ "ldr x15, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + "ldr x9, [%[params_ptr], #" STR(OFFSET_INPUT_ROW_SIZE) "]\n" + "cmp x15, #16\n" + "add x12, %[input_ptr], x15\n" + "add x13, %[input_ptr], x9\n" + "ld1 {v8.8b}, [%[input_ptr]], #8\n" + "add x14, x13, x15\n" + "ld1 {v9.8b}, [x12], #8\n" + "ldr x6, [%[params_ptr], #" STR(OFFSET_FILTER_ROW_SIZE) "]\n" + + "add x9, %[filter_ptr], x15\n" + "ld1 {v10.8b}, [x13], #8\n" + "add x10, %[filter_ptr], x6\n" + "ld1 {v11.8b}, [x14], #8\n" + "ld1 {v0.8b}, [%[filter_ptr]], #8\n" + "add x11, x10, x15\n" + "ld1 {v1.8b}, [x9], #8\n" + "ld1 {v2.8b}, [x10], #8\n" + "ld1 {v3.8b}, [x11], #8\n" + + // Load constants. + "ldr w6, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "dup v26.8h, w6\n" + "ldr w6, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w7\n" + "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "dup v28.4s, w6\n" + "ldr w6, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "neg w7, w7\n" + "dup v29.4s, w7\n" + "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v30.4s, w6\n" + "ldr w6, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + "dup v31.4s, w7\n" + "dup v25.8h, w6\n" + + // Add input and filter offsets. 
+ "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + + "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "subs x15, x15, #8\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "ld1 {v8.8b}, [%[input_ptr]], #8\n" + "cmp x15, #16\n" + "ld1 {v0.8b}, [%[filter_ptr]], #8\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "ld1 {v9.8b}, [x12], #8\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "ld1 {v1.8b}, [x9], #8\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "ld1 {v10.8b}, [x13], #8\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "ld1 {v2.8b}, [x10], #8\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + "ld1 {v11.8b}, [x14], #8\n" + "ld1 {v3.8b}, [x11], #8\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]], #8\n" + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw 
v2.8h, v25.8h, v2.8b\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + + "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n" + + DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]]\n" + : + // Outputs. + [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr) + : + // Inputs. + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. + "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v16", "v17", "v18", + "v19", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + // We use these general-purpose registers. 
+ "x6", "x7", "x9", "x10", "x11", "x12", "x13", "x14", "x15"); +#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP +#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP } }; template <> -struct ConvKernel3x3FilterDepth8<1, 2, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8; - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, - temp_8; - - const uint8* ptr = input_ptr; - - // Load all inputs for top output. - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - temp_8 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = 
vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Second output. - output_ptr += output_depth; - - ptr = input_ptr + 3 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - DotProductAndStore( - filter, input_2, input_0, input_1, input_5, input_3, input_4, input_8, - input_6, input_7, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); +struct DepthwiseConvPartial { + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, + const DepthwiseConvParams* params_ptr) { +#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1" +#define 
DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2" + asm volatile( + // Performs depthwise convolutions for an input window of size 2x3 and + // padding of 1 across the full depth. Expects |input_ptr| and + // |filter_ptr| to be pointing to the beginning of the 2x3 input and + // filter values. + + // Load input and filter values. + "ldr x7, [%[params_ptr], #" STR(OFFSET_INPUT_DEPTH) "]\n" + "mov x12, %[input_ptr]\n" + "ldr x11, [%[params_ptr], #" STR(OFFSET_INPUT_ROW_SIZE) "]\n" + "mov x9, %[filter_ptr]\n" + "ldr x14, [%[params_ptr], #" STR(OFFSET_FILTER_ROW_SIZE) "]\n" + "add x13, x12, x11\n" + "ldr x15, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + + "ld1 {v8.8b}, [x12], x7\n" + "add x10, x9, x14\n" + "ld1 {v9.8b}, [x12], x7\n" + "cmp x15, #16\n" + "ld1 {v10.8b}, [x12]\n" + "add %[input_ptr], %[input_ptr], #8\n" + "ld1 {v11.8b}, [x13], x7\n" + "add %[filter_ptr], %[filter_ptr], #8\n" + "ld1 {v12.8b}, [x13], x7\n" + "ld1 {v13.8b}, [x13]\n" + + "ld1 {v0.8b}, [x9], x7\n" + "ld1 {v1.8b}, [x9], x7\n" + "ld1 {v2.8b}, [x9]\n" + "ld1 {v3.8b}, [x10], x7\n" + "ld1 {v4.8b}, [x10], x7\n" + "ld1 {v5.8b}, [x10]\n" + + // Load constants. + "ldr w12, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "dup v26.8h, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w13\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "dup v28.4s, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "neg w13, w13\n" + "dup v29.4s, w13\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v30.4s, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + "dup v31.4s, w13\n" + "dup v25.8h, w12\n" + + // Add input and filter offsets. 
+ "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + "uaddw v4.8h, v25.8h, v4.8b\n" + "uaddw v5.8h, v25.8h, v5.8b\n" + + "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n" + "mov x12, %[input_ptr]\n" + "subs x15, x15, #8\n" + "add x13, x12, x11\n" + "cmp x15, #16\n" + "add %[input_ptr], %[input_ptr], #8\n" + + "smlal v16.4s, v0.4h, v8.4h\n" + "mov x9, %[filter_ptr]\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "ld1 {v8.8b}, [x12], x7\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "add x10, x9, x14\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "ld1 {v9.8b}, [x12], x7\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "add %[filter_ptr], %[filter_ptr], #8\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "ld1 {v10.8b}, [x12]\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "ld1 {v0.8b}, [x9], x7\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + "ld1 {v11.8b}, [x13], x7\n" + "smlal v16.4s, v4.4h, v12.4h\n" + "ld1 {v1.8b}, [x9], x7\n" + "smlal2 v17.4s, v4.8h, v12.8h\n" + "ld1 {v12.8b}, [x13], x7\n" + "smlal v16.4s, v5.4h, v13.4h\n" + "ld1 {v2.8b}, [x9]\n" + "smlal2 v17.4s, v5.8h, v13.8h\n" + "ld1 {v13.8b}, [x13]\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "ld1 {v3.8b}, [x10], x7\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "ld1 {v4.8b}, [x10], x7\n" + "and v18.16b, v16.16b, v29.16b\n" + "ld1 {v5.8b}, [x10]\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + 
"smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "uaddw v8.8h, v26.8h, v8.8b\n" + "st1 {v16.8b}, [%[output_ptr]], #8\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v4.8h, v25.8h, v4.8b\n" + "uaddw v5.8h, v25.8h, v5.8b\n" + + "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n" + + DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + "smlal v16.4s, v4.4h, v12.4h\n" + "smlal2 v17.4s, v4.8h, v12.8h\n" + "smlal v16.4s, v5.4h, v13.4h\n" + "smlal2 v17.4s, v5.8h, v13.8h\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]]\n" + : + // Outputs. 
+ [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr) + : + // Inputs. + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. + "v0", "v1", "v2", "v3", "v4", "v5", "v8", "v9", "v10", "v11", "v12", + "v13", "v16", "v17", "v18", "v19", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", + // We use these general-purpose registers. + "x7", "x9", "x10", "x11", "x12", "x13", "x14", "x15"); +#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP +#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP } }; template <> -struct ConvKernel3x3FilterDepth8<1, 4, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8; - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, - temp_8; - - const uint8* ptr = input_ptr; - - // Load all inputs for top output. 
- temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - temp_8 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Second output. 
- output_ptr += output_depth; - - ptr = input_ptr + 3 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - DotProductAndStore( - filter, input_2, input_0, input_1, input_5, input_3, input_4, input_8, - input_6, input_7, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Third output. 
- output_ptr += output_depth; - - ptr = input_ptr + 5 * input_depth; - temp_2 = vld1_u8(ptr); - temp_0 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_5 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_8 = vld1_u8(ptr); - temp_6 = vld1_u8(ptr + input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - - DotProductAndStore( - filter, input_1, input_2, input_0, input_4, input_5, input_3, input_7, - input_8, input_6, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Fourth output. 
- output_ptr += output_depth; - - ptr = input_ptr + 7 * input_depth; - temp_1 = vld1_u8(ptr); - temp_2 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_7 = vld1_u8(ptr); - temp_8 = vld1_u8(ptr + input_depth); - - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); +struct DepthwiseConvPartial { + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, + const DepthwiseConvParams* params_ptr) { +#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1" +#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2" + asm volatile( + // Performs depthwise convolutions for an input window of size 3x2 and + // padding of 1 across the full depth. Expects |input_ptr| and + // |filter_ptr| to be pointing to the beginning of the 3x2 input and + // filter values. + + // Load input and filter values. 
+ "ldr x6, [%[params_ptr], #" STR(OFFSET_INPUT_DEPTH) "]\n" + "mov x12, %[input_ptr]\n" + "ldr x11, [%[params_ptr], #" STR(OFFSET_INPUT_ROW_SIZE) "]\n" + "mov x7, %[filter_ptr]\n" + "ldr x5, [%[params_ptr], #" STR(OFFSET_FILTER_ROW_SIZE) "]\n" + "add x13, x12, x11\n" + "ldr x15, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + "add x14, x13, x11\n" + + "ld1 {v8.8b}, [x12], x6\n" + "add x9, x7, x5\n" + "ld1 {v9.8b}, [x12]\n" + "cmp x15, #16\n" + "add x10, x9, x5\n" + "ld1 {v10.8b}, [x13], x6\n" + "add %[input_ptr], %[input_ptr], #8\n" + "ld1 {v11.8b}, [x13]\n" + "add %[filter_ptr], %[filter_ptr], #8\n" + "ld1 {v12.8b}, [x14], x6\n" + "ld1 {v13.8b}, [x14]\n" + + "ld1 {v0.8b}, [x7], x6\n" + "ld1 {v1.8b}, [x7]\n" + "ld1 {v2.8b}, [x9], x6\n" + "ld1 {v3.8b}, [x9]\n" + "ld1 {v4.8b}, [x10], x6\n" + "ld1 {v5.8b}, [x10]\n" + + // Load constants. + "ldr w12, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "dup v26.8h, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w13\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "dup v28.4s, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "neg w13, w13\n" + "dup v29.4s, w13\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v30.4s, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + "dup v31.4s, w13\n" + "dup v25.8h, w12\n" + + // Add input and filter offsets. 
+ "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + "uaddw v4.8h, v25.8h, v4.8b\n" + "uaddw v5.8h, v25.8h, v5.8b\n" + + "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n" + "mov x12, %[input_ptr]\n" + "subs x15, x15, #8\n" + "add x13, x12, x11\n" + "cmp x15, #16\n" + "add x14, x13, x11\n" + "add %[input_ptr], %[input_ptr], #8\n" + + "smlal v16.4s, v0.4h, v8.4h\n" + "mov x7, %[filter_ptr]\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "ld1 {v8.8b}, [x12], x6\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "add x9, x7, x5\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "add x10, x9, x5\n" + "ld1 {v9.8b}, [x12]\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "add %[filter_ptr], %[filter_ptr], #8\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "ld1 {v10.8b}, [x13], x6\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "ld1 {v0.8b}, [x7], x6\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + "ld1 {v11.8b}, [x13]\n" + "smlal v16.4s, v4.4h, v12.4h\n" + "ld1 {v1.8b}, [x7]\n" + "smlal2 v17.4s, v4.8h, v12.8h\n" + "ld1 {v12.8b}, [x14], x6\n" + "smlal v16.4s, v5.4h, v13.4h\n" + "ld1 {v2.8b}, [x9], x6\n" + "smlal2 v17.4s, v5.8h, v13.8h\n" + "ld1 {v13.8b}, [x14]\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "ld1 {v3.8b}, [x9]\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "ld1 {v4.8b}, [x10], x6\n" + "and v18.16b, v16.16b, v29.16b\n" + "ld1 {v5.8b}, [x10]\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + 
"add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "uaddw v8.8h, v26.8h, v8.8b\n" + "st1 {v16.8b}, [%[output_ptr]], #8\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v4.8h, v25.8h, v4.8b\n" + "uaddw v5.8h, v25.8h, v5.8b\n" + + "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n" + + DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + "smlal v16.4s, v4.4h, v12.4h\n" + "smlal2 v17.4s, v4.8h, v12.8h\n" + "smlal v16.4s, v5.4h, v13.4h\n" + "smlal2 v17.4s, v5.8h, v13.8h\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]]\n" + : + // Outputs. 
+ [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr) + : + // Inputs. + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. + "v0", "v1", "v2", "v3", "v4", "v5", "v8", "v9", "v10", "v11", "v12", + "v13", "v16", "v17", "v18", "v19", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", + // We use these general-purpose registers. + "x5", "x6", "x7", "x9", "x10", "x11", "x12", "x13", "x14", "x15"); +#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP +#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP } }; -template -struct ConvKernel3x3FilterDepth8<1, 1, kFixedStrideWidth, kFixedStrideHeight> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8; - - uint8x8_t temp_0 = vld1_u8(input_ptr); - uint8x8_t temp_1 = vld1_u8(input_ptr + input_depth); - uint8x8_t temp_2 = vld1_u8(input_ptr + 2 * input_depth); - - input_ptr += input_row_size; - uint8x8_t temp_3 = vld1_u8(input_ptr); - uint8x8_t temp_4 = vld1_u8(input_ptr + input_depth); - uint8x8_t temp_5 = vld1_u8(input_ptr + 2 * input_depth); - - input_ptr += input_row_size; - uint8x8_t temp_6 = vld1_u8(input_ptr); - uint8x8_t temp_7 = vld1_u8(input_ptr + input_depth); - uint8x8_t temp_8 = vld1_u8(input_ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = 
vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - } -}; - -inline void ShuffleInput(const uint8* input_ptr, int input_depth, - int input_width, int input_height, int output_depth, - int output_width, int output_height, - uint8* output_ptr) { - const int input_row_size = input_depth * input_width; - - for (int y = 0; y < output_height; y++) { +#undef OFFSET_INPUT_DEPTH +#undef OFFSET_INPUT_ROW_SIZE +#undef OFFSET_OUTPUT_DEPTH +#undef OFFSET_OUTPUT_ROW_SIZE +#undef OFFSET_INPUT_OFFSET +#undef OFFSET_OUTPUT_OFFSET +#undef OFFSET_FILTER_OFFSET +#undef OFFSET_OUTPUT_MULTIPLIER +#undef OFFSET_OUTPUT_ACTIVATION_MIN +#undef OFFSET_OUTPUT_ACTIVATION_MAX +#undef OFFSET_OUTPUT_SHIFT +#undef OFFSET_INPUT_WIDTH +#undef OFFSET_INPUT_HEIGHT +#undef OFFSET_OUTPUT_WIDTH +#undef OFFSET_OUTPUT_HEIGHT +#undef STR +#undef STR_UNEXPANDED + +// Copies a subset of the input designated by |input_ptr| into |output_ptr| +// with the specified output dimensions. Supports output depths of 64 only as +// this is the cache line size. 
+inline void ShuffleInput(const uint8* input_ptr, int64_t input_depth, + int32 input_width, int32 input_height, + int64_t output_depth, int32 output_width, + int32 output_height, uint8* output_ptr) { + const int64_t input_row_size = input_depth * input_width; + for (int32 y = 0; y < output_height; y++) { const uint8* ptr = input_ptr; - for (int x = 0; x < output_width; x++) { + for (int32 x = 0; x < output_width; x++) { memcpy(output_ptr, ptr, output_depth); output_ptr += output_depth; ptr += input_depth; @@ -3873,561 +2937,262 @@ inline void ShuffleInput(const uint8* input_ptr, int input_depth, } } -template -struct ConvRow3x3FilterDepth8 {}; - -template -struct ConvRow3x3FilterDepth8<1, kFixedStrideWidth, kFixedStrideHeight> { - static inline void Run(const uint8* input_data, int start_x, int start_y, - int input_depth, int input_width, int input_height, - int input_row_size, int32 input_offset, - const uint8* filter_data, int32 filter_offset, - const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - int output_depth, int output_width, - uint8* shuffle_workspace) { - int out_x = start_x; - - // 1x4 at a time. 
- for (; out_x <= output_width - 4; out_x += 4) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<1, 4, kFixedStrideWidth, kFixedStrideHeight>:: - Run(input_ptr, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += 4 * kFixedStrideWidth * input_depth; - output_data += 4 * output_depth; - } - - // 1x1 at a time. - for (; out_x < output_width; out_x++) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<1, 1, kFixedStrideWidth, kFixedStrideHeight>:: - Run(input_ptr, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } +// Calculates the input size depending on stride and output. +inline int32 get_shuffle_input_size(int32 stride, int32 output) { + return stride * (output - 1) + 3; +} - input_data += kFixedStrideWidth * input_depth; - output_data += output_depth; - } +// Indicates the input and output dimensions used when shuffling input +// activations. 
+struct ShuffleParams { + int32 output_width; + int32 output_height; + int32 input_width; + int32 input_height; + + ShuffleParams() = default; + ShuffleParams(int32 output_width, int32 output_height, int32 stride_width, + int32 stride_height) + : output_width(output_width) + , output_height(output_height) + , input_width(get_shuffle_input_size(stride_width, output_width)) + , input_height(get_shuffle_input_size(stride_height, output_height)) { } }; -template -struct ConvRow3x3FilterDepth8<2, kFixedStrideWidth, kFixedStrideHeight> { - static inline void Run(const uint8* input_data, int start_x, int start_y, - int input_depth, int input_width, int input_height, - int input_row_size, int32 input_offset, - const uint8* filter_data, int32 filter_offset, - const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - int output_depth, int output_width, - uint8* shuffle_workspace) { - int out_x = start_x; - - // 2x4 at a time. - for (; out_x <= output_width - 4; out_x += 4) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<2, 4, kFixedStrideWidth, kFixedStrideHeight>:: - Run(input_ptr, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += 4 * kFixedStrideWidth * input_depth; - output_data += 4 * output_depth; - } - - // 2x2 at a time. 
- for (; out_x <= output_width - 2; out_x += 2) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<2, 2, kFixedStrideWidth, kFixedStrideHeight>:: - Run(input_ptr, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += 2 * kFixedStrideWidth * input_depth; - output_data += 2 * output_depth; - } - - // 2x1 at a time. - for (; out_x < output_width; out_x++) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<2, 1, kFixedStrideWidth, kFixedStrideHeight>:: - Run(input_ptr, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += kFixedStrideWidth * input_depth; - output_data += output_depth; +template +struct DepthwiseConvThroughDepth { + // Runs the DepthwiseConvWindow kernels through the depth dimension from + // |start_depth| to |end_depth|. Keep this not inlined to maintain a small + // binary size. We use a DepthwiseConvParams struct for read only params + // to minimize call overhead. 
+ static __attribute__((noinline)) void Run(const uint8* input_ptr, + const uint8* filter_ptr, const int32* bias_ptr, uint8* output_ptr, + int64_t start_depth, int64_t end_depth, int64_t input_depth, + int64_t input_row_size, int32 output_window_height, + int32 output_window_width, const DepthwiseConvParams& params) { + for (; start_depth <= end_depth - 8; start_depth += 8) { + DepthwiseConvWindow<8, kStrideWidth, kStrideHeight>::Run( + input_ptr, filter_ptr, bias_ptr, output_ptr, input_depth, + input_row_size, output_window_height, output_window_width, &params); + input_ptr += 8; + output_ptr += 8; + filter_ptr += 8; + bias_ptr += 8; } } }; -template <> -struct ConvRow3x3FilterDepth8<4, 1, 1> { - static inline void Run(const uint8* input_data, int start_x, int start_y, - int input_depth, int input_width, int input_height, - int input_row_size, int32 input_offset, - const uint8* filter_data, int32 filter_offset, - const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - int output_depth, int output_width, - uint8* shuffle_workspace) { - int out_x = start_x; - - // 4x4 at a time. 
- for (; out_x <= output_width - 4; out_x += 4) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<4, 4, 1, 1>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } +template +struct DepthwiseConvMultiRow { + using ConvKernel = DepthwiseConvThroughDepth; - input_data += 4 * input_depth; - output_data += 4 * output_depth; - } - - // Handle the rest of the right side. - // 4x2 at a time. - for (; out_x <= output_width - 2; out_x += 2) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<4, 2, 1, 1>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += 2 * input_depth; - output_data += 2 * output_depth; - } - - // 4x1 at a time. 
- for (; out_x < output_width; out_x++) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<4, 1, 1, 1>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += input_depth; - output_data += output_depth; - } - } -}; - -template <> -struct ConvRow3x3FilterDepth8<4, 2, 2> { - // The buffer size of the shuffled input. - static inline constexpr int ShuffleWorkspaceSize() { return 64 * 9 * 9; } - - static inline void Run(const uint8* input_data, int start_x, int start_y, - int input_depth, int input_width, int input_height, - int input_row_size, int32 input_offset, - const uint8* filter_data, int32 filter_offset, - const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - int output_depth, int output_width, + static inline void Run(const uint8* input_data, int32 start_x, int32 end_x, + const uint8* filter_data, const int32* bias_data, + uint8* output_data, const DepthwiseConvParams& params, + const ShuffleParams& shuffle_params, uint8* shuffle_workspace) { - // Branch and cache misses increase substantially with stride 2 kernels. - // Adding prefetching reduces latency by as much as 2x. 
- const int i0 = 0; - const int i1 = input_depth; - const int i2 = 2 * input_depth; - const int i3 = 3 * input_depth; - const int i4 = 4 * input_depth; - const int i5 = 5 * input_depth; - const int i6 = 6 * input_depth; - const int i7 = 7 * input_depth; - const int i8 = 8 * input_depth; - -#define DEPTHWISECONV_PRELOAD_ROW(input_ptr, i) \ - preload_l1_keep(input_ptr + i * input_row_size + i0); \ - preload_l1_keep(input_ptr + i * input_row_size + i1); \ - preload_l1_keep(input_ptr + i * input_row_size + i2); \ - preload_l1_keep(input_ptr + i * input_row_size + i3); \ - preload_l1_keep(input_ptr + i * input_row_size + i4); \ - preload_l1_keep(input_ptr + i * input_row_size + i5); \ - preload_l1_keep(input_ptr + i * input_row_size + i6); \ - preload_l1_keep(input_ptr + i * input_row_size + i7); \ - preload_l1_keep(input_ptr + i * input_row_size + i8); - - int out_x = start_x; - // 4x4 at a time. - for (; out_x <= output_width - 4; out_x += 4) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - int depth = 0; - for (; depth <= output_depth - 64; depth += 64) { - // Preload 9x9 input. - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 0); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 1); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 2); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 3); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 4); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 5); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 6); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 7); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 8); - - // For a large input window (64x9x9) that is small enough to fit in L1 - // cache, copy the input into a separate buffer and run the kernel on - // this new buffer. This reduces the likelihood of cache misses when - // the kernel is loading input data. If this size is ever changed, - // update the ShuffleWorkspaceSize() function to return the new size. 
- ShuffleInput(input_ptr, input_depth, input_width, input_height, 64, 9, - 9, shuffle_workspace); - const uint8* shuffled_ptr = &shuffle_workspace[0]; - - for (int micro_depth = 0; micro_depth <= 64 - 8; micro_depth += 8) { - ConvKernel3x3FilterDepth8<4, 4, 2, 2>::Run( - shuffled_ptr, 64, input_offset, 64 * 9, filter_ptr, filter_offset, - bias_ptr, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, - output_depth, output_width); - - shuffled_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; + TFLITE_DCHECK(shuffle_params.input_height == + get_shuffle_input_size(kStrideHeight, shuffle_params.output_height)); + TFLITE_DCHECK(shuffle_params.input_width == + get_shuffle_input_size(kStrideWidth, shuffle_params.output_width)); + TFLITE_DCHECK(64 * shuffle_params.input_width * shuffle_params.input_height + <= DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE); + + int32 out_x = start_x; + + // Run shuffling on inputs with sufficiently large depth and width. When + // these parameters are large enough, more time is taken to load inputs + // from memory. At this point, it becomes useful to prefetch and + // preshuffle the input data to maximize locality. + if (params.output_depth > 64 || + (params.output_depth <= 64 && params.input_width > 150)) { + for (; out_x <= (end_x - shuffle_params.output_width); + out_x += shuffle_params.output_width) { + const uint8* input_ptr = input_data; + const int32* bias_ptr = bias_data; + const uint8* filter_ptr = filter_data; + uint8* output_ptr = output_data; + int64_t depth = 0; + const int64_t shuffle_row_size = 64 * shuffle_params.input_width; + + for (; depth <= params.output_depth - 64; depth += 64) { + // Preload. 
+ const uint8* h_ptr = input_ptr; + for (int32 i = 0; i < shuffle_params.input_height; i++) { + const uint8* ptr = h_ptr; + for (int32 j = 0; j < shuffle_params.input_width; j++) { + asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) :); + ptr += params.input_depth; + } + h_ptr += params.input_row_size; + } + + // For a large enough input, shuffle into buckets. + ShuffleInput(input_ptr, params.input_depth, params.input_width, + params.input_height, 64, shuffle_params.input_width, + shuffle_params.input_height, shuffle_workspace); + ConvKernel::Run(shuffle_workspace, filter_ptr, bias_ptr, output_ptr, + 0, 64, 64, shuffle_row_size, + shuffle_params.output_height, + shuffle_params.output_width, params); + input_ptr += 64; + output_ptr += 64; + filter_ptr += 64; + bias_ptr += 64; } - input_ptr += 64; - } - - // Preload 9x9 input one more time for the rest of the depth. - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 0); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 1); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 2); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 3); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 4); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 5); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 6); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 7); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 8); - - for (; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<4, 4, 2, 2>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += 4 * 2 * input_depth; - output_data += 4 * output_depth; - } - -#undef DEPTHWISECONV_PRELOAD_ROW - - // Handle the rest of the right side. - // 4x2 at a time. 
- for (; out_x <= output_width - 2; out_x += 2) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; + // Preload. + const uint8* h_ptr = input_ptr; + for (int32 i = 0; i < shuffle_params.input_height; i++) { + const uint8* ptr = h_ptr; + for (int32 j = 0; j < shuffle_params.input_width; j++) { + asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) :); + ptr += params.input_depth; + } + h_ptr += params.input_row_size; + } - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<4, 2, 2, 2>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); + // Handle leftover depth. + ConvKernel::Run(input_ptr, filter_ptr, bias_ptr, output_ptr, + depth, params.output_depth, params.input_depth, + params.input_row_size, shuffle_params.output_height, + shuffle_params.output_width, params); - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; + input_data += + shuffle_params.output_width * kStrideWidth * params.input_depth; + output_data += shuffle_params.output_width * params.output_depth; } - - input_data += 2 * 2 * input_depth; - output_data += 2 * output_depth; } - // 4x1 at a time. 
- for (; out_x < output_width; out_x++) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<4, 1, 2, 2>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += 2 * input_depth; - output_data += output_depth; + const int32 output_leftover_width = end_x - out_x; + if (output_leftover_width > 0) { + ConvKernel::Run(input_data, filter_data, bias_data, output_data, 0, + params.output_depth, params.input_depth, + params.input_row_size, shuffle_params.output_height, + output_leftover_width, params); } } }; -template <> -struct ConvRow3x3FilterDepth8<8, 2, 2> { - static inline void Run(const uint8* input_data, int start_x, int start_y, - int input_depth, int input_width, int input_height, - int input_row_size, int32 input_offset, - const uint8* filter_data, int32 filter_offset, - const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - int output_depth, int output_width, - uint8* shuffle_workspace) { - // Reuse 4 row kernels twice. 
- ConvRow3x3FilterDepth8<4, 2, 2>::Run( - input_data, start_x, start_y, input_depth, input_width, input_height, - input_row_size, input_offset, filter_data, filter_offset, bias_data, - output_offset, output_multiplier, output_shift, output_activation_min, - output_activation_max, output_data, output_depth, output_width, - shuffle_workspace); - - ConvRow3x3FilterDepth8<4, 2, 2>::Run( - input_data + 2 * 4 * input_row_size, start_x, start_y + 4, input_depth, - input_width, input_height, input_row_size, input_offset, filter_data, - filter_offset, bias_data, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_data + 4 * output_depth * output_width, output_depth, - output_width, shuffle_workspace); +// Processes the borders of the input for pad_width and pad_height = 1. +// Calls 4 asm kernels: +// * 1x1 input shape. +// * Corner edges. +// * Horizontal edges. +// * Vertical edges. +inline void DepthwiseConvHandlePadding(const uint8* input_data, + const uint8* filter_data, const int32* bias_data, uint8* output_data, + const DepthwiseConvParams& params) { + if (params.input_width == 1 && params.input_height == 1) { + const uint8* filter_ptr = filter_data + params.filter_row_size + + params.output_depth; + DepthwiseConvPartial::Run(input_data, filter_ptr, + bias_data, output_data, &params); + return; } -}; -template <> -struct ConvRow3x3FilterDepth8<8, 1, 1> { - // The buffer size of the shuffled input.
- static inline constexpr int ShuffleWorkspaceSize() { return 64 * 10 * 10; } - - static inline void Run(const uint8* input_data, int start_x, int start_y, - int input_depth, int input_width, int input_height, - int input_row_size, int32 input_offset, - const uint8* filter_data, int32 filter_offset, - const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - int output_depth, int output_width, - uint8* shuffle_workspace) { - int out_x = start_x; - // 8x8 at a time. - for (; out_x <= output_width - 8; out_x += 8) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - int depth = 0; - for (; depth <= output_depth - 64; depth += 64) { - // For a large input window (64x10x10) that is small enough to fit in L1 - // cache, copy the input into a separate buffer and run the kernel on - // this new buffer. This reduces the likelihood of cache misses when - // the kernel is loading input data. If the size of the input window - // changes, update the function ShuffleWorkspaceSize() with the new - // size. 
- ShuffleInput(input_ptr, input_depth, input_width, input_height, 64, 10, - 10, shuffle_workspace); - const uint8* shuffled_ptr = shuffle_workspace; - - for (int micro_depth = 0; micro_depth <= 64 - 8; micro_depth += 8) { - ConvKernel3x3FilterDepth8<8, 8, 1, 1>::Run( - shuffled_ptr, 64, input_offset, 64 * 10, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - shuffled_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - input_ptr += 64; - } + const int32 out_x_start_corner = 0; + const int32 out_x_end_corner = params.output_width - 1; + const int32 out_y_start_corner = 0; + const int32 out_y_end_corner = params.output_height - 1; + + // Handle top row. + const uint8* input_ptr = input_data; + const uint8* filter_ptr = filter_data + params.filter_row_size + + params.output_depth; + uint8* output_ptr = output_data; + + DepthwiseConvPartial::Run(input_ptr, filter_ptr, + bias_data, output_ptr, &params); + + input_ptr += (params.stride_width - 1) * params.input_depth; + filter_ptr = filter_data + params.filter_row_size; + output_ptr += params.output_depth; + + for (int32 out_x = out_x_start_corner + 1; out_x < out_x_end_corner; + out_x++) { + DepthwiseConvPartial::Run( + input_ptr, filter_ptr, bias_data, output_ptr, &params); + input_ptr += params.stride_width * params.input_depth; + output_ptr += params.output_depth; + } - for (; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<8, 8, 1, 1>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } + DepthwiseConvPartial::Run(input_ptr, filter_ptr, + bias_data, output_ptr, &params); - input_data += 8 *
input_depth; - output_data += 8 * output_depth; - } + // Handle left side. + input_ptr = input_data + (params.stride_width - 1) * params.input_row_size; + filter_ptr = filter_data + params.input_depth; + output_ptr = output_data + params.output_row_size; - // Handle the rest of the right side by re-using 4 row kernels twice. - ConvRow3x3FilterDepth8<4, 1, 1>::Run( - input_data, out_x, start_y, input_depth, input_width, input_height, - input_row_size, input_offset, filter_data, filter_offset, bias_data, - output_offset, output_multiplier, output_shift, output_activation_min, - output_activation_max, output_data, output_depth, output_width, - shuffle_workspace); - - ConvRow3x3FilterDepth8<4, 1, 1>::Run( - input_data + 4 * input_row_size, out_x, start_y + 4, input_depth, - input_width, input_height, input_row_size, input_offset, filter_data, - filter_offset, bias_data, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_data + 4 * output_depth * output_width, output_depth, - output_width, shuffle_workspace); + for (int32 out_y = out_y_start_corner + 1; out_y < out_y_end_corner; + out_y++) { + DepthwiseConvPartial::Run( + input_ptr, filter_ptr, bias_data, output_ptr, &params); + input_ptr += params.stride_width * params.input_row_size; + output_ptr += params.output_row_size; } -}; -inline bool Fast3x3FilterKernelSupported(const Dims<4>& input_dims, - const Dims<4>& filter_dims, - int stride_width, int stride_height, - int pad_width, int pad_height, - int depth_multiplier, - const Dims<4>& output_dims) { - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int input_depth = ArraySize(input_dims, 0); - const int filter_height = ArraySize(filter_dims, 2); - const int filter_width = ArraySize(filter_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - - bool supported = filter_width == 3 &&
filter_height == 3 && - depth_multiplier == 1 && - (stride_width == 1 || stride_width == 2) && - (stride_height == 1 || stride_height == 2) && - (stride_width == stride_height) && pad_width == 0 && - pad_height == 0 && (input_depth % 8) == 0; + // Handle right side. + input_ptr = input_data + (params.input_width - 2) * params.input_depth + + (params.stride_width - 1) * params.input_row_size; + filter_ptr = filter_data; + output_ptr = output_data + params.output_row_size + + (params.output_width - 1) * params.output_depth; + + for (int32 out_y = out_y_start_corner + 1; out_y < out_y_end_corner; + out_y++) { + DepthwiseConvPartial::Run( + input_ptr, filter_ptr, bias_data, output_ptr, &params); + input_ptr += params.stride_width * params.input_row_size; + output_ptr += params.output_row_size; + } + + // Handle bottom row. + input_ptr = input_data + (params.input_height - 2) * params.input_row_size; + filter_ptr = filter_data + params.output_depth; + output_ptr = output_data + + (params.output_height - 1) * params.output_row_size; + + DepthwiseConvPartial::Run(input_ptr, filter_ptr, + bias_data, output_ptr, &params); + + input_ptr += (params.stride_width == 1) ?
0 : params.input_depth; + filter_ptr = filter_data; + output_ptr += params.output_depth; + + for (int32 out_x = out_x_start_corner + 1; out_x < out_x_end_corner; + out_x++) { + DepthwiseConvPartial::Run( + input_ptr, filter_ptr, bias_data, output_ptr, &params); + input_ptr += params.stride_width * params.input_depth; + output_ptr += params.output_depth; + } + + DepthwiseConvPartial::Run(input_ptr, filter_ptr, + bias_data, output_ptr, &params); +} + +inline bool Fast3x3FilterKernelSupported( + const Dims<4>& input_dims, const Dims<4>& filter_dims, int32 stride_width, + int32 stride_height, int32 pad_width, int32 pad_height, + int32 depth_multiplier, const Dims<4>& output_dims, int32 output_shift) { + const int32 input_height = ArraySize(input_dims, 2); + const int32 input_width = ArraySize(input_dims, 1); + const int32 input_depth = ArraySize(input_dims, 0); + const int32 filter_height = ArraySize(filter_dims, 2); + const int32 filter_width = ArraySize(filter_dims, 1); + const int32 output_height = ArraySize(output_dims, 2); + const int32 output_width = ArraySize(output_dims, 1); + + bool supported = + filter_width == 3 && filter_height == 3 && depth_multiplier == 1 && + (stride_width == 1 || stride_width == 2) && + (stride_height == 1 || stride_height == 2) && + (stride_width == stride_height) && (pad_width == 0 || pad_width == 1) && + (pad_height == 0 || pad_height == 1) && (pad_width == pad_height) && + (input_depth % 8) == 0 && (output_shift > 0); if (!supported) { return false; @@ -4436,145 +3201,193 @@ inline bool Fast3x3FilterKernelSupported(const Dims<4>& input_dims, // Handle case where padding is zero but padding type is not kValid. // This would require special boundary case handling that is not supported.
- const int out_x = output_width - 1; - const int out_y = output_height - 1; + const int32 out_x = output_width - 1; + const int32 out_y = output_height - 1; - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; + const int32 in_x_origin = (out_x * stride_width) - pad_width; + const int32 in_y_origin = (out_y * stride_height) - pad_height; - const int in_x_end = in_x_origin + filter_width; - const int in_y_end = in_y_origin + filter_height; + const int32 in_x_end = in_x_origin + filter_width; + const int32 in_y_end = in_y_origin + filter_height; // Supported only if filter on the right and bottom boundary lies completely - // within the input. - return in_x_end <= input_width && in_y_end <= input_height; + // within the input if padding is zero. + if (pad_width == 0 && pad_height == 0) { + return in_x_end <= input_width && in_y_end <= input_height; + } + + // Else if padding is 1, supported if bottom right filter lies +1 past input + // width and height. + supported = in_x_end <= (input_width + 1) && in_y_end <= (input_height + 1); + + if (!supported) { + return false; + } + + // Shapes with width 1 and height > 1, and vice versa are not supported yet. 
+ if (input_width == 1) { + supported = (input_width == input_height); + } else if (input_height == 1) { + supported = (input_width == input_height); + } + return supported; } inline void DepthwiseConv3x3Filter( const uint8* input_data, const Dims<4>& input_dims, int32 input_offset, const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, int stride_width, - int stride_height, int pad_width, int pad_height, int depth_multiplier, - int32 output_offset, int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int input_depth = ArraySize(input_dims, 0); - const int filter_height = ArraySize(filter_dims, 2); - const int filter_width = ArraySize(filter_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - - // Algorithm assumes below constraints. It is optimized for depth multiplier - // of 1, 3x3 filter, no padding and strides 1 and 2. 
- TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); + const int32* bias_data, const Dims<4>& bias_dims, int32 stride_width, + int32 stride_height, int32 pad_width, int32 pad_height, + int32 depth_multiplier, int32 output_offset, int32 output_multiplier, + int32 output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + DepthwiseConvParams params; + params.input_depth = ArraySize(input_dims, 0); + params.input_width = ArraySize(input_dims, 1); + params.input_height = ArraySize(input_dims, 2); + params.input_row_size = params.input_depth * params.input_width; + params.input_offset = input_offset; + params.stride_width = stride_width; + params.stride_height = stride_height; + params.output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + params.output_width = ArraySize(output_dims, 1); + params.output_height = ArraySize(output_dims, 2); + params.output_row_size = params.output_depth * params.output_width; + params.output_offset = output_offset; + params.filter_offset = filter_offset; + params.output_multiplier = output_multiplier; + params.output_shift = output_shift; + params.output_activation_min = output_activation_min; + params.output_activation_max = output_activation_max; + + const int32 filter_height = ArraySize(filter_dims, 2); + const int32 filter_width = ArraySize(filter_dims, 1); + params.filter_row_size = params.output_depth * filter_width; + + // Algorithm assumes below constraints. It is optimized for depth + // multiplier of 1, 3x3 filter, no padding and strides 1 and 2. 
+ TFLITE_DCHECK(params.output_depth == params.input_depth * depth_multiplier); TFLITE_DCHECK(depth_multiplier == 1); TFLITE_DCHECK(filter_height == 3); TFLITE_DCHECK(filter_width == 3); - TFLITE_DCHECK(pad_height == 0); - TFLITE_DCHECK(pad_width == 0); TFLITE_DCHECK(stride_height == 1 || stride_height == 2); TFLITE_DCHECK(stride_width == 1 || stride_width == 2); TFLITE_DCHECK(stride_width == stride_height); + TFLITE_DCHECK(pad_height == 0 || pad_height == 1); + TFLITE_DCHECK(pad_width == 0 || pad_width == 1); + TFLITE_DCHECK(pad_width == pad_height); + + const int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int64_t input_batch_size = params.input_row_size * params.input_height; + const int64_t output_batch_size = + params.output_row_size * params.output_height; + + ShuffleParams one_row_shuffle_params, two_row_shuffle_params, + four_row_shuffle_params, eight_row_shuffle_params; + if (stride_width == 1) { + one_row_shuffle_params = ShuffleParams(30, 1, 1, 1); + two_row_shuffle_params = ShuffleParams(22, 2, 1, 1); + four_row_shuffle_params = ShuffleParams(14, 4, 1, 1); + eight_row_shuffle_params = ShuffleParams(8, 8, 1, 1); + } else { + one_row_shuffle_params = ShuffleParams(14, 1, 2, 2); + two_row_shuffle_params = ShuffleParams(8, 2, 2, 2); + four_row_shuffle_params = ShuffleParams(4, 4, 2, 2); + eight_row_shuffle_params = ShuffleParams(2, 8, 2, 2); + } - const int input_row_size = input_depth * (input_width + 2 * pad_width); - const int output_row_size = output_depth * output_width; - const int input_batch_size = input_row_size * (input_height + 2 * pad_height); - const int output_batch_size = output_depth * output_width * output_height; - - using conv_row_func_t = decltype(&ConvRow3x3FilterDepth8<1, 1, 1>::Run); - conv_row_func_t conv_1_output_row = ConvRow3x3FilterDepth8<1, 1, 1>::Run; - conv_row_func_t conv_2_output_rows = ConvRow3x3FilterDepth8<2, 1, 1>::Run; - conv_row_func_t conv_4_output_rows = ConvRow3x3FilterDepth8<4, 1, 1>::Run; 
- conv_row_func_t conv_8_output_rows = ConvRow3x3FilterDepth8<8, 1, 1>::Run; - + using conv_multirow_func_t = decltype(&DepthwiseConvMultiRow<1, 1>::Run); + conv_multirow_func_t conv_multirow_func = DepthwiseConvMultiRow<1, 1>::Run; if (stride_width == 2) { - conv_1_output_row = ConvRow3x3FilterDepth8<1, 2, 2>::Run; - conv_2_output_rows = ConvRow3x3FilterDepth8<2, 2, 2>::Run; - conv_4_output_rows = ConvRow3x3FilterDepth8<4, 2, 2>::Run; - conv_8_output_rows = ConvRow3x3FilterDepth8<8, 2, 2>::Run; + conv_multirow_func = DepthwiseConvMultiRow<2, 2>::Run; } // Allocate maximum memory needed for shuffled input. // TODO(mariewhite): The size of this workspace is small enough to be // allocated on the stack. Eventually we will want to move it to the heap - // and have it allocated outside of this function, like the im2col_array used - // in gemmlowp. -#define DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE 10 * 10 * 64 + // and have it allocated outside of this function, like the im2col_array + // used in gemmlowp. uint8 shuffle_workspace[DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE]; - // Make sure the kernels using this buffer will not run out of bounds. 
- static_assert(ConvRow3x3FilterDepth8<8, 1, 1>::ShuffleWorkspaceSize() <= - DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE, - "Shuffle workspace size is too small."); - static_assert(ConvRow3x3FilterDepth8<4, 2, 2>::ShuffleWorkspaceSize() <= - DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE, - "Shuffle workspace size is too small."); - -#undef DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE - - for (int b = 0; b < batches; ++b) { + for (int32 b = 0; b < batches; ++b) { const uint8* input_ptr = input_data + b * input_batch_size; uint8* output_ptr = output_data + b * output_batch_size; - int out_y = 0; + int32 out_x = 0; + int32 out_y = 0; + int32 end_x = params.output_width; + int32 end_y = params.output_height; + + if (pad_width == 1 && pad_height == 1) { + DepthwiseConvHandlePadding(input_ptr, filter_data, bias_data, output_ptr, + params); + + // Update extents now that the edges have been handled. + out_x = 1; + end_x = params.output_width - 1; + out_y = 1; + end_y = params.output_height - 1; + const int in_x = (out_x * stride_width) - pad_width; + const int in_y = (out_y * stride_height) - pad_height; + input_ptr += in_y * params.input_row_size + in_x * params.input_depth; + output_ptr += out_y * params.output_row_size + + out_x * params.output_depth; + } + + // Shuffling shapes that maximize width over the shuffle workspace size + // perform better since the inputs are closer together, minimizing + // shuffling time. + // + // If the input shape has width large enough for the 2 row kernels, + // we prefer to use this. The innermost loop of the kernels handle + // 2 height x 2 width so this is the fastest path. + // + // If the input shape has smaller width but larger height, shuffling is + // still useful and can benefit from kernels 4 row and 8 row kernels. // Handle 8 rows at a time. 
- for (; out_y <= output_height - 8; out_y += 8) { - conv_8_output_rows(input_ptr, 0, out_y, input_depth, input_width, - input_height, input_row_size, input_offset, - filter_data, filter_offset, bias_data, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, - output_width, shuffle_workspace); - - input_ptr += 8 * stride_height * input_row_size; - output_ptr += 8 * output_row_size; + if (params.input_width < four_row_shuffle_params.input_width) { + for (; out_y <= end_y - 8; out_y += 8) { + conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data, + output_ptr, params, eight_row_shuffle_params, + shuffle_workspace); + input_ptr += 8 * stride_height * params.input_row_size; + output_ptr += 8 * params.output_row_size; + } } // Handle 4 rows at a time. - for (; out_y <= output_height - 4; out_y += 4) { - conv_4_output_rows(input_ptr, 0, out_y, input_depth, input_width, - input_height, input_row_size, input_offset, - filter_data, filter_offset, bias_data, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, - output_width, shuffle_workspace); - - input_ptr += 4 * stride_height * input_row_size; - output_ptr += 4 * output_row_size; + if (params.input_width < two_row_shuffle_params.input_width) { + for (; out_y <= end_y - 4; out_y += 4) { + conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data, + output_ptr, params, four_row_shuffle_params, + shuffle_workspace); + input_ptr += 4 * stride_height * params.input_row_size; + output_ptr += 4 * params.output_row_size; + } } // Handle 2 rows at a time. 
- for (; out_y <= output_height - 2; out_y += 2) { - conv_2_output_rows(input_ptr, 0, out_y, input_depth, input_width, - input_height, input_row_size, input_offset, - filter_data, filter_offset, bias_data, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, - output_width, shuffle_workspace); - - input_ptr += 2 * stride_height * input_row_size; - output_ptr += 2 * output_row_size; + for (; out_y <= end_y - 2; out_y += 2) { + conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data, + output_ptr, params, two_row_shuffle_params, + shuffle_workspace); + input_ptr += 2 * stride_height * params.input_row_size; + output_ptr += 2 * params.output_row_size; } // Handle one row at a time. - for (; out_y < output_height; out_y++) { - conv_1_output_row(input_ptr, 0, out_y, input_depth, input_width, - input_height, input_row_size, input_offset, filter_data, - filter_offset, bias_data, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, - output_width, shuffle_workspace); - - input_ptr += stride_height * input_row_size; - output_ptr += output_row_size; + for (; out_y < end_y; out_y++) { + conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data, + output_ptr, params, one_row_shuffle_params, + shuffle_workspace); + input_ptr += stride_height * params.input_row_size; + output_ptr += params.output_row_size; } } } +// clang-format on #endif // __aarch64__ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h index 0bfb4e9b1f8ee4167cfb629645a38538be1d73d4..27d9224512a835ea58911031f1b4d6dcf5482ba9 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h @@ -129,8 +129,9 @@ class EigenTensorConvFunctor { const int 
conv_width = output_height * output_width; Eigen::array, 1> dim_pair; dim_pair[0] = Eigen::IndexPair(1, 0); - EigenMatrix output(output_data, conv_width, filter_count); - ConstEigenMatrix input(input_data, conv_width, input_depth); + EigenMatrix output(output_data, input_batches * conv_width, filter_count); + ConstEigenMatrix input(input_data, input_batches * conv_width, + input_depth); ConstEigenMatrix filter(filter_data, input_depth, filter_count); MatMulConvFunctor()(device, output, input, filter, dim_pair); diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc index 65f25168e3a20748973549c2f7385b14863294eb..38ad32c734a2286c7d23162810625169a4d8df43 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -56,9 +56,12 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, m_cols - (m_cols & (kFloatWeightsPerNeonLane - 1)); // The arrays used to cache the vector. + void* aligned_vector_cache_free = nullptr; float32x4_t* vector_cache_float32x4 = - new float32x4_t[(m_cols / kFloatWeightsPerNeonLane) * - sizeof(float32x4_t)]; + reinterpret_cast(aligned_alloc( + sizeof(float32x4_t), (postamble_start >> 2) * sizeof(float32x4_t), + &aligned_vector_cache_free)); + const int kUnrollSize = 2; for (int b = 0; b < n_batch; b++) { float* result_in_batch = result + b * m_rows * result_stride; @@ -71,7 +74,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, matrix_ptr1 = matrix + m_cols; } - // Cahce the vector. + // Cache the vector. 
for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) { vector_cache_float32x4[c >> 2] = vld1q_f32(vector_in_batch + c); } @@ -128,7 +131,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, result_in_batch += result_stride; } } - delete[] vector_cache_float32x4; + free(aligned_vector_cache_free); } void NeonMatrixBatchVectorMultiplyAccumulate( @@ -294,9 +297,12 @@ void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector, v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); // The arrays used to cache the vector. + void* aligned_vector_cache_free = nullptr; float32x4_t* vector_cache_float32x4 = - new float32x4_t[(v_size / kFloatWeightsPerNeonLane) * - sizeof(float32x4_t)]; + reinterpret_cast<float32x4_t*>(aligned_alloc( + sizeof(float32x4_t), (postamble_start >> 2) * sizeof(float32x4_t), + &aligned_vector_cache_free)); + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { vector_cache_float32x4[v >> 2] = vld1q_f32(vector + v); } @@ -322,7 +328,7 @@ void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector, result_ptr += v_size; batch_vector_ptr += v_size; } - delete[] vector_cache_float32x4; + free(aligned_vector_cache_free); } void NeonSub1Vector(const float* vector, int v_size, float* result) { @@ -346,6 +352,30 @@ void NeonSub1Vector(const float* vector, int v_size, float* result) { } } +bool NeonIsZeroVector(const float* vector, int v_size) { + // If v_size is not divisible by kFloatWeightsPerNeonLane, we cannot + // use the main vectorized loop, and we need to process sequentially. + // postamble_start shows the start index where this should happen.
+ const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + + const float32x4_t zero_x4_float = vmovq_n_f32(0.0f); + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + const float32x4_t i_x4_float = vld1q_f32(vector + v); + uint32x4_t cmp_result = vceqq_f32(i_x4_float, zero_x4_float); + if (vgetq_lane_u32(cmp_result, 0) == 0) return false; + if (vgetq_lane_u32(cmp_result, 1) == 0) return false; + if (vgetq_lane_u32(cmp_result, 2) == 0) return false; + if (vgetq_lane_u32(cmp_result, 3) == 0) return false; + } + + // Postamble loop + for (int v = postamble_start; v < v_size; ++v) { + if (vector[v] != 0.0) return false; + } + return true; +} + void NeonClipVector(const float* vector, int v_size, float abs_limit, float* result) { // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h index 9e60d0657b49ed5ee1f2e999acb86ebee0eab972..7a5a8fc54123946229963abd1720030d0bb358bf 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -100,6 +100,11 @@ void ZeroVector(float* vector, int v_size) { float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); } +// Check if all entries of a vector are zero. 
+bool IsZeroVector(const float* vector, int v_size) { + return NEON_OR_PORTABLE(IsZeroVector, vector, v_size); +} + void ClipVector(const float* vector, int v_size, float abs_limit, float* result) { NEON_OR_PORTABLE(ClipVector, vector, v_size, abs_limit, result); diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index 580d208bebc7cd21825d4ec75e373ceeafb91ba0..8115a072d5008b116c83c208e011453fe2196996 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -48,6 +48,15 @@ using reference_ops::Greater; using reference_ops::GreaterEqual; using reference_ops::Less; using reference_ops::LessEqual; +using reference_ops::RankOneSelect; +using reference_ops::Select; + +// TODO(b/80247582) Remove this constant. +// This will be phased out as the shifts are revised with more thought. Use of a +// constant enables us to track progress on this work. +// +// Used mainly to convert from old-style shifts (right) to new-style (left). +static constexpr int kReverseShift = -1; // Make a local VectorMap typedef allowing to map a float array // as a Eigen vector expression. The std::conditional here is to @@ -65,7 +74,7 @@ using VectorMap = typename std::conditional< template VectorMap MapAsVector(Scalar* data, const Dims& dims) { - const int size = RequiredBufferSizeForDims(dims); + const int size = FlatSize(dims); return VectorMap(data, size, 1); } @@ -138,6 +147,45 @@ MatrixMap MapAsMatrixWithGivenNumberOfRows(Scalar* data, return MatrixMap(data, rows, cols); } +// This is like the template-parameter version, except that the power-of-two is +// passed as a function parameter. The template version is to be preferred, +// since some target hardware optimizations depend on the range of the exponent. 
+template <typename IntegerType> +IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) { + if (exponent == 0) { + return x; + } + using ScalarIntegerType = + typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType; + const IntegerType min = + gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min()); + const IntegerType max = + gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max()); + const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); + + const std::int32_t threshold = + ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1); + const IntegerType positive_mask = + gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold)); + const IntegerType negative_mask = + gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold)); + + IntegerType result = gemmlowp::ShiftLeft(x, exponent); + result = gemmlowp::SelectUsingMask(positive_mask, max, result); + result = gemmlowp::SelectUsingMask(negative_mask, min, result); + return result; +} + +// This is like the template-parameter version, except that the power-of-two is +// passed as a function parameter. See raw-integer version for further comments. +template <typename tRawType, int tIntegerBits> +gemmlowp::FixedPoint<tRawType, tIntegerBits> +SaturatingRoundingMultiplyByPOTParam( + gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) { + return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw( + SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent)); +} + // DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE // BROADCASTING.
// @@ -247,8 +295,8 @@ inline void AddBiasAndEvalActivationFunction(const float* bias_data, float output_activation_max) { #ifdef USE_NEON gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction"); - const int bias_size = bias_dims.sizes[3] * bias_dims.strides[3]; - const int array_size = array_dims.sizes[3] * array_dims.strides[3]; + const int bias_size = FlatSize(bias_dims); + const int array_size = FlatSize(array_dims); TFLITE_DCHECK_EQ((array_size % bias_size), 0); float* array_ptr = array_data; float* array_end_ptr = array_ptr + array_size; @@ -298,8 +346,8 @@ inline void AddBiasAndEvalActivationFunction(const float* bias_data, } #else // not NEON gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction"); - const int bias_size = bias_dims.sizes[3] * bias_dims.strides[3]; - const int array_size = array_dims.sizes[3] * array_dims.strides[3]; + const int bias_size = FlatSize(bias_dims); + const int array_size = FlatSize(array_dims); TFLITE_DCHECK_EQ((array_size % bias_size), 0); for (int array_offset = 0; array_offset < array_size; array_offset += bias_size) { @@ -370,10 +418,8 @@ inline void GEMVForLstmCell(const uint8* input_data, const Dims<4>& input_dims, TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims)); TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims)); TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); - TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * - ArraySize(output_dims, 3), - 1); - const int input_size = input_dims.strides[3]; + TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_dims, 0), 1); + const int input_size = FlatSizeSkipDim(input_dims, 3); const int output_size = MatchingArraySize(weights_dims, 1, output_dims, 0); // This special fast path for quantized LSTM cells does not try to support // odd sizes that we haven't encountered in any LSTM cell, that would @@ -556,10 +602,8 @@ inline void GEMVForLstmCellWithSymmetricRange( TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims)); 
TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims)); TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); - TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * - ArraySize(output_dims, 3), - 1); - const int input_size = input_dims.strides[3]; + TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_dims, 0), 1); + const int input_size = FlatSizeSkipDim(input_dims, 3); const int output_size = MatchingArraySize(weights_dims, 1, output_dims, 0); // This special fast path for quantized LSTM cells does not try to support // odd sizes that we haven't encountered in any LSTM cell, that would @@ -892,10 +936,8 @@ inline void FullyConnectedAsGEMV( TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims)); TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); - TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * - ArraySize(output_dims, 3), - 1); - const int input_size = input_dims.strides[3]; + TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_dims, 0), 1); + const int input_size = FlatSizeSkipDim(input_dims, 3); const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0); static constexpr int kPeel = 4; for (int k = 0; k < input_size; k += 64) { @@ -1076,8 +1118,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, // but the current --variable_batch hack consists in overwriting the 3rd // dimension with the runtime batch size, as we don't keep track for each // array of which dimension is the batch dimension in it. 
- const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * - ArraySize(output_dims, 3); + const int batches = FlatSizeSkipDim(output_dims, 0); #ifdef USE_NEON const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0); if (batches == 1 && !(output_size % 4)) { @@ -1133,8 +1174,7 @@ inline void FullyConnected( // but the current --variable_batch hack consists in overwriting the 3rd // dimension with the runtime batch size, as we don't keep track for each // array of which dimension is the batch dimension in it. - const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * - ArraySize(output_dims, 3); + const int batches = FlatSizeSkipDim(output_dims, 0); const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0); const int accum_depth = ArraySize(filter_dims, 0); TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); @@ -1549,8 +1589,7 @@ inline void ExperimentalShuffledFullyConnected( // but the current --variable_batch hack consists in overwriting the 3rd // dimension with the runtime batch size, as we don't keep track for each // array of which dimension is the batch dimension in it. 
- const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * - ArraySize(output_dims, 3); + const int batches = FlatSizeSkipDim(output_dims, 0); const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0); const int accum_depth = ArraySize(weights_dims, 0); TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); @@ -1737,6 +1776,100 @@ inline void ExtractPatchIntoBufferColumn( } } +template <typename T> +void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, + const Dims<4>& filter_dims, int stride_width, + int stride_height, int dilation_width_factor, + int dilation_height_factor, int pad_width, int pad_height, + const Dims<4>& output_dims, uint8 byte_zero, + T* im2col_data) { + // For dilated convolution, the input pixels are not contiguous therefore we + // can't use the same optimizations as Im2Col(). Though note this code would + // work fine for the non-dilated case too (though likely a bit slower). + gemmlowp::ScopedProfilingLabel label("DilatedIm2col"); + TFLITE_DCHECK(dilation_width_factor != 1 || dilation_height_factor != 1); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + TFLITE_DCHECK(im2col_data); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + MatchingArraySize(output_dims, 0, filter_dims, 3); + + // Construct the MxN sized im2col matrix.
+ // The rows M, are sub-ordered B x H x W + Dims<4> row_dims; + row_dims.sizes[0] = output_width; + row_dims.sizes[1] = output_height; + row_dims.sizes[2] = batches; + row_dims.sizes[3] = 1; + ComputeStrides(&row_dims); + + // The columns, N, are sub-ordered Kh x Kw x Din + Dims<4> col_dims; + col_dims.sizes[0] = input_depth; + col_dims.sizes[1] = filter_width; + col_dims.sizes[2] = filter_height; + col_dims.sizes[3] = 1; + ComputeStrides(&col_dims); + + // Use dimensions M and N to construct dims for indexing directly into im2col + Dims<4> im2col_dims; + im2col_dims.sizes[0] = col_dims.strides[3]; + im2col_dims.sizes[1] = row_dims.strides[3]; + im2col_dims.sizes[2] = 1; + im2col_dims.sizes[3] = 1; + ComputeStrides(&im2col_dims); + + // Loop through the output rows (B x H x W) + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + // Each row is an output pixel. Arrange the input data into this row in + // an order we can conveniently multiply with the filter data. + int row_offset = Offset(row_dims, out_x, out_y, batch, 0); + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Loop through all the pixels of the filter (Kh x Kw) + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + const int in_y = in_y_origin + dilation_height_factor * filter_y; + if ((in_y >= 0) && (in_y < input_height)) { + // Filter row is within the input data. + // Loop through all the filter pixels in this row. + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int in_x = in_x_origin + dilation_width_factor * filter_x; + int col_offset = Offset(col_dims, 0, filter_x, filter_y, 0); + T* dst = im2col_data + + Offset(im2col_dims, col_offset, row_offset, 0, 0); + if ((in_x >= 0) && (in_x < input_width)) { + // Filter pixel is within the input, copy the data. 
+ T const* src = + input_data + Offset(input_dims, 0, in_x, in_y, batch); + memcpy(dst, src, input_depth * sizeof(T)); + } else { + // Filter pixel is outside the input, zero it out. + memset(dst, byte_zero, input_depth * sizeof(T)); + } + } + } else { + // Filter row is outside the input, zero out the entire im2col row. + int col_offset = Offset(col_dims, 0, 0, filter_y, 0); + T* dst = + im2col_data + Offset(im2col_dims, col_offset, row_offset, 0, 0); + memset(dst, byte_zero, filter_width * input_depth * sizeof(T)); + } + } + } + } + } +} + template void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width, int stride_height, int pad_width, int pad_height, int kheight, @@ -1777,74 +1910,6 @@ void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, kwidth, byte_zero, output_data, output_dims); } -inline void DilatedConv(const float* input_data, const Dims<4>& input_dims, - const float* filter_data, const Dims<4>& filter_dims, - const float* bias_data, const Dims<4>& bias_dims, - int stride_width, int stride_height, - int dilation_width_factor, int dilation_height_factor, - int pad_width, int pad_height, - float output_activation_min, - float output_activation_max, float* output_data, - const Dims<4>& output_dims, float* im2col_data, - const Dims<4>& im2col_dims) { - gemmlowp::ScopedProfilingLabel label("DilatedConv"); - // This is a copy of the reference Conv implementation. We do not currently - // have an optimized path for dilation. - (void)im2col_data; // only used in optimized code. - (void)im2col_dims; // only used in optimized code. 
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); - const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); - if (bias_data) { - TFLITE_DCHECK_EQ(ArraySize(filter_dims, 3), ArraySize(bias_dims, 0)); - } - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int filter_height = ArraySize(filter_dims, 2); - const int filter_width = ArraySize(filter_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - for (int batch = 0; batch < batches; ++batch) { - for (int out_y = 0; out_y < output_height; ++out_y) { - for (int out_x = 0; out_x < output_width; ++out_x) { - for (int out_channel = 0; out_channel < output_depth; ++out_channel) { - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; - float total = 0.f; - for (int filter_y = 0; filter_y < filter_height; ++filter_y) { - for (int filter_x = 0; filter_x < filter_width; ++filter_x) { - for (int in_channel = 0; in_channel < input_depth; ++in_channel) { - const int in_x = in_x_origin + dilation_width_factor * filter_x; - const int in_y = - in_y_origin + dilation_height_factor * filter_y; - // If the location is outside the bounds of the input image, - // use zero as a default value. 
- if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && - (in_y < input_height)) { - float input_value = input_data[Offset(input_dims, in_channel, - in_x, in_y, batch)]; - float filter_value = - filter_data[Offset(filter_dims, in_channel, filter_x, - filter_y, out_channel)]; - total += (input_value * filter_value); - } - } - } - } - float bias_value = 0.0f; - if (bias_data) { - bias_value = bias_data[Offset(bias_dims, out_channel, 0, 0, 0)]; - } - output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] = - ActivationFunctionWithMinMax(total + bias_value, - output_activation_min, - output_activation_max); - } - } - } - } -} - inline void Conv(const float* input_data, const Dims<4>& input_dims, const float* filter_data, const Dims<4>& filter_dims, const float* bias_data, const Dims<4>& bias_dims, @@ -1853,29 +1918,32 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims, float output_activation_min, float output_activation_max, float* output_data, const Dims<4>& output_dims, float* im2col_data, const Dims<4>& im2col_dims) { - if ((dilation_width_factor != 1) || (dilation_height_factor != 1)) { - return DilatedConv(input_data, input_dims, filter_data, filter_dims, - bias_data, bias_dims, stride_width, stride_height, - dilation_width_factor, dilation_height_factor, pad_width, - pad_height, output_activation_min, output_activation_max, - output_data, output_dims, im2col_data, im2col_dims); - } - (void)im2col_data; (void)im2col_dims; gemmlowp::ScopedProfilingLabel label("Conv"); + // A float set to 0x00000000h == 0.0f + const uint8 float_zero_byte = 0x00; const float* gemm_input_data = nullptr; const Dims<4>* gemm_input_dims = nullptr; const int filter_width = ArraySize(filter_dims, 1); const int filter_height = ArraySize(filter_dims, 2); + const bool need_dilated_im2col = + dilation_width_factor != 1 || dilation_height_factor != 1; const bool need_im2col = stride_width != 1 || stride_height != 1 || filter_width != 1 || filter_height != 1; 
- if (need_im2col) { + if (need_dilated_im2col) { + DilatedIm2col(input_data, input_dims, filter_dims, stride_width, + stride_height, dilation_width_factor, dilation_height_factor, + pad_width, pad_height, output_dims, float_zero_byte, + im2col_data); + gemm_input_data = im2col_data; + gemm_input_dims = &im2col_dims; + } else if (need_im2col) { TFLITE_DCHECK(im2col_data); Im2col(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_height, filter_width, 0, im2col_data, - im2col_dims); + pad_height, filter_height, filter_width, float_zero_byte, + im2col_data, im2col_dims); gemm_input_data = im2col_data; gemm_input_dims = &im2col_dims; } else { @@ -1986,13 +2054,21 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, } const int gemm_input_rows = gemm_input_dims->sizes[0]; + // Using FlatSizeSkipDim causes segfault in some contexts (see b/79927784). + // The root cause has not yet been identified though. Same applies below for + // the other calls commented out. This is a partial rollback of cl/196819423. + // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_dims, 0); const int gemm_input_cols = gemm_input_dims->sizes[1] * gemm_input_dims->sizes[2] * gemm_input_dims->sizes[3]; const int filter_rows = filter_dims.sizes[3]; + // See b/79927784. + // const int filter_cols = FlatSizeSkipDim(filter_dims, 3); const int filter_cols = filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2]; const int output_rows = output_dims.sizes[0]; + // See b/79927784. 
+ // const int output_cols = FlatSizeSkipDim(output_dims, 0); const int output_cols = output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3]; TFLITE_DCHECK_EQ(output_rows, filter_rows); @@ -2148,14 +2224,11 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims, Ac == FusedActivationFunctionType::kRelu1, ""); const int input_rows = input_dims.sizes[0]; - const int input_cols = - input_dims.sizes[1] * input_dims.sizes[2] * input_dims.sizes[3]; + const int input_cols = FlatSizeSkipDim(input_dims, 0); const int filter_rows = filter_dims.sizes[3]; - const int filter_cols = - filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2]; + const int filter_cols = FlatSizeSkipDim(filter_dims, 3); const int output_rows = output_dims.sizes[0]; - const int output_cols = - output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3]; + const int output_cols = FlatSizeSkipDim(output_dims, 0); TFLITE_DCHECK_EQ(output_rows, filter_rows); TFLITE_DCHECK_EQ(output_cols, input_cols); TFLITE_DCHECK_EQ(filter_cols, input_rows); @@ -2219,27 +2292,15 @@ void NonGlobalBatchNormalization( const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("NonGlobalBatchNormalization"); const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int height = - MatchingArraySize(input_dims, 2, mean_dims, 2, multiplier_dims, 2, - offset_dims, 2, output_dims, 2); - const int width = - MatchingArraySize(input_dims, 1, mean_dims, 1, multiplier_dims, 1, - offset_dims, 1, output_dims, 1); - const int depth = - MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0, - offset_dims, 0, output_dims, 0); + const int inner_size = MatchingFlatSizeSkipDim( + input_dims, 3, mean_dims, multiplier_dims, offset_dims, output_dims); for (int b = 0; b < batches; ++b) { - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - for (int c = 0; c < depth; ++c) { - output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( 
- (input_data[Offset(input_dims, c, x, y, b)] - - mean_data[Offset(mean_dims, c, x, y, 0)]) * - multiplier_data[Offset(multiplier_dims, c, x, y, 0)] + - offset_data[Offset(offset_dims, c, x, y, 0)]); - } - } + for (int i = 0; i < inner_size; ++i) { + *output_data = ActivationFunction( + (*input_data - mean_data[i]) * multiplier_data[i] + offset_data[i]); + ++output_data; + ++input_data; } } } @@ -2254,24 +2315,17 @@ void GlobalBatchNormalization(const float* input_data, const Dims<4>& offset_dims, float* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("GlobalBatchNormalization"); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int height = MatchingArraySize(input_dims, 2, output_dims, 2); - const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); const int depth = MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0, offset_dims, 0, output_dims, 0); - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - for (int c = 0; c < depth; ++c) { - output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( - (input_data[Offset(input_dims, c, x, y, b)] - - mean_data[Offset(mean_dims, c, 0, 0, 0)]) * - multiplier_data[Offset(multiplier_dims, c, 0, 0, 0)] + - offset_data[Offset(offset_dims, c, 0, 0, 0)]); - } - } + for (int i = 0; i < outer_size; ++i) { + for (int c = 0; c < depth; ++c) { + *output_data = ActivationFunction( + (*input_data - mean_data[c]) * multiplier_data[c] + offset_data[c]); + ++output_data; + ++input_data; } } } @@ -2288,44 +2342,26 @@ inline void Relu(const float* input_data, const Dims<4>& input_dims, inline void Relu1(const float* input_data, const Dims<4>& input_dims, float* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)"); - const int batches = MatchingArraySize(input_dims, 
3, output_dims, 3); - const int height = MatchingArraySize(input_dims, 2, output_dims, 2); - const int width = MatchingArraySize(input_dims, 1, output_dims, 1); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - for (int c = 0; c < depth; ++c) { - float val = input_data[Offset(input_dims, c, x, y, b)]; - const float upper = 1; - const float lower = -1; - float clamped = val > upper ? upper : val < lower ? lower : val; - output_data[Offset(output_dims, c, x, y, b)] = clamped; - } - } - } + const int flat_size = MatchingFlatSize(input_dims, output_dims); + for (int i = 0; i < flat_size; ++i) { + const float val = input_data[i]; + const float upper = 1; + const float lower = -1; + const float clamped = val > upper ? upper : val < lower ? lower : val; + output_data[i] = clamped; } } inline void Relu6(const float* input_data, const Dims<4>& input_dims, float* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)"); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int height = MatchingArraySize(input_dims, 2, output_dims, 2); - const int width = MatchingArraySize(input_dims, 1, output_dims, 1); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - for (int c = 0; c < depth; ++c) { - float val = input_data[Offset(input_dims, c, x, y, b)]; - const float upper = 6; - const float lower = 0; - float clamped = val > upper ? upper : val < lower ? 
lower : val; - output_data[Offset(output_dims, c, x, y, b)] = clamped; - } - } - } + const int flat_size = MatchingFlatSize(input_dims, output_dims); + for (int i = 0; i < flat_size; ++i) { + const float val = input_data[i]; + const float upper = 6; + const float lower = 0; + const float clamped = val > upper ? upper : val < lower ? lower : val; + output_data[i] = clamped; } } @@ -2334,24 +2370,19 @@ void L2Normalization(const float* input_data, const Dims<4>& input_dims, float* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("L2Normalization"); static_assert(Ac == FusedActivationFunctionType::kNone, ""); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int height = MatchingArraySize(input_dims, 2, output_dims, 2); - const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - float squared_l2_norm = 0; - for (int c = 0; c < depth; ++c) { - float val = input_data[Offset(input_dims, c, x, y, b)]; - squared_l2_norm += val * val; - } - float inverse_l2_norm = 1.0f / std::sqrt(squared_l2_norm); - for (int c = 0; c < depth; ++c) { - output_data[Offset(output_dims, c, x, y, b)] = - input_data[Offset(input_dims, c, x, y, b)] * inverse_l2_norm; - } - } + for (int i = 0; i < outer_size; ++i) { + float squared_l2_norm = 0; + for (int c = 0; c < depth; ++c) { + const float val = input_data[c]; + squared_l2_norm += val * val; + } + const float l2_norm = std::sqrt(squared_l2_norm); + for (int c = 0; c < depth; ++c) { + *output_data = *input_data / l2_norm; + ++output_data; + ++input_data; } } } @@ -2405,32 +2436,31 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, int32 input_zero_point, uint8* output_data, const Dims<4>& 
output_dims) { gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit"); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int height = MatchingArraySize(input_dims, 2, output_dims, 2); - const int width = MatchingArraySize(input_dims, 1, output_dims, 1); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); - TFLITE_DCHECK_EQ(batches, 1); - TFLITE_DCHECK_EQ(height, 1); - TFLITE_DCHECK_EQ(width, 1); - int32 square_l2_norm = 0; - for (int i = 0; i < depth; i++) { - int32 diff = input_data[i] - input_zero_point; - square_l2_norm += diff * diff; - } - int32 inv_l2norm_multiplier; - int inv_l2norm_shift; - GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier, - &inv_l2norm_shift); - - for (int i = 0; i < depth; i++) { - int32 diff = input_data[i] - input_zero_point; - int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne( - 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); - int32 unclamped_output_val = 128 + rescaled_diff; - int32 output_val = std::min(255, std::max(0, unclamped_output_val)); - output_data[i] = static_cast(output_val); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); + for (int i = 0; i < outer_size; ++i) { + int32 square_l2_norm = 0; + for (int c = 0; c < depth; c++) { + int32 diff = input_data[c] - input_zero_point; + square_l2_norm += diff * diff; + } + int32 inv_l2norm_multiplier; + int inv_l2norm_shift; + GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier, + &inv_l2norm_shift); + + for (int c = 0; c < depth; c++) { + int32 diff = *input_data - input_zero_point; + int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp( + 128 * diff, inv_l2norm_multiplier, kReverseShift * inv_l2norm_shift); + int32 unclamped_output_val = 128 + 
rescaled_diff; + int32 output_val = std::min(255, std::max(0, unclamped_output_val)); + *output_data = static_cast<uint8>(output_val); + ++input_data; + ++output_data; + } } } @@ -2439,20 +2469,12 @@ inline void Add(const float* input1_data, const Dims<4>& input1_dims, float output_activation_min, float output_activation_max, float* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("Add"); - /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3, - output_dims, 3); - /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2, - output_dims, 2); - /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1, - output_dims, 1); - /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0, - output_dims, 0); TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); int i = 0; - const int size = input1_dims.sizes[3] * input1_dims.strides[3]; + const int size = MatchingFlatSize(input1_dims, input2_dims, output_dims); #ifdef USE_NEON const auto activation_min = vdupq_n_f32(output_activation_min); const auto activation_max = vdupq_n_f32(output_activation_max); @@ -2499,52 +2521,17 @@ inline void Add(const float* input1_data, const Dims<4>& input1_dims, } } -// legacy, for compatibility with old checked-in code -template <FusedActivationFunctionType Ac> -void Add(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float* output_data, const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - - Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, - output_activation_max, output_data, output_dims); -} - -template <FusedActivationFunctionType Ac> -inline void Add(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32
input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, int input2_shift, - int32 output_offset, int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - gemmlowp::ScopedProfilingLabel label("Add/8bit"); - /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3, - output_dims, 3); - /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2, - output_dims, 2); - /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1, - output_dims, 1); - /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0, - output_dims, 0); - TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); - +// Element-wise add that can often be used for inner loop of broadcast add as +// well as the non-broadcast add. 
+inline void AddElementwise(int size, int left_shift, const uint8* input1_data, + int32 input1_offset, int32 input1_multiplier, + int input1_shift, const uint8* input2_data, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data) { int i = 0; - const int size = input1_dims.sizes[3] * input1_dims.strides[3]; TFLITE_DCHECK_GT(input1_offset, -256); TFLITE_DCHECK_GT(input2_offset, -256); TFLITE_DCHECK_LT(input1_offset, 256); @@ -2609,20 +2596,73 @@ inline void Add(int left_shift, const uint8* input1_data, const int32 input2_val = input2_offset + input2_data[i]; const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); - const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); - const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; - const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + - output_offset; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, output_multiplier, kReverseShift * output_shift) + + output_offset; const int32 clamped_output = std::min( output_activation_max, std::max(output_activation_min, raw_output)); output_data[i] = static_cast(clamped_output); } } +// legacy, for compatibility with old checked-in 
code +template +void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, + output_activation_max, output_data, output_dims); +} + +template +inline void Add(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, int input2_shift, + int32 output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + gemmlowp::ScopedProfilingLabel label("Add/8bit"); + const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + + TFLITE_DCHECK_GT(input1_offset, -256); + TFLITE_DCHECK_GT(input2_offset, -256); + TFLITE_DCHECK_LT(input1_offset, 256); + TFLITE_DCHECK_LT(input2_offset, 256); + AddElementwise(flat_size, left_shift, input1_data, input1_offset, + input1_multiplier, input1_shift, input2_data, input2_offset, + input2_multiplier, input2_shift, output_offset, + 
output_multiplier, output_shift, output_activation_min, + output_activation_max, output_data); +} + template inline void Add(const int16* input1_data, const Dims<4>& input1_dims, int input1_shift, const int16* input2_data, @@ -2643,9 +2683,7 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims, TFLITE_DCHECK_EQ(output_activation_max, 32767); } - const int flat_size = RequiredBufferSizeForDims(output_dims); - TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input1_dims), flat_size); - TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input2_dims), flat_size); + const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); TFLITE_DCHECK(input1_shift == 0 || input2_shift == 0); TFLITE_DCHECK_GE(input1_shift, 0); @@ -2681,10 +2719,10 @@ void Add(const int32* input1_data, const Dims<4>& input1_dims, auto output_map = MapAsVector(output_data, output_dims); if (AreSameDims(input1_dims, input2_dims)) { output_map.array() = input1_map.array() + input2_map.array(); - } else if (RequiredBufferSizeForDims(input2_dims) == 1) { + } else if (FlatSize(input2_dims) == 1) { auto scalar = input2_data[0]; output_map.array() = input1_map.array() + scalar; - } else if (RequiredBufferSizeForDims(input1_dims) == 1) { + } else if (FlatSize(input1_dims) == 1) { auto scalar = input1_data[0]; output_map.array() = scalar + input2_map.array(); } else { @@ -2789,15 +2827,17 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + 
MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -2833,27 +2873,11 @@ inline void BroadcastAddFivefold( input2_data_ptr = input2_data_reset; for (int i2 = 0; i2 < y2; ++i2) { for (int i1 = 0; i1 < y1; ++i1) { - for (int i0 = 0; i0 < y0; ++i0) { - const int32 input1_val = input1_offset + input1_data_ptr[i0]; - const int32 input2_val = input2_offset + input2_data_ptr[i0]; - const int32 shifted_input1_val = input1_val * (1 << left_shift); - const int32 shifted_input2_val = input2_val * (1 << left_shift); - const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); - const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); - const int32 raw_sum = scaled_input1_val + scaled_input2_val; - const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + - output_offset; - const int32 clamped_output = - std::min(output_activation_max, - std::max(output_activation_min, raw_output)); - output_data_ptr[i0] = static_cast(clamped_output); - } + AddElementwise( + y0, left_shift, input1_data_ptr, input1_offset, input1_multiplier, + input1_shift, input2_data_ptr, input2_offset, input2_multiplier, + input2_shift, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data_ptr); input2_data_ptr += y0; output_data_ptr += y0; } @@ -2924,20 +2948,12 @@ inline void Mul(const float* input1_data, const 
Dims<4>& input1_dims, float output_activation_min, float output_activation_max, float* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("Mul"); - /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3, - output_dims, 3); - /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2, - output_dims, 2); - /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1, - output_dims, 1); - /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0, - output_dims, 0); TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); int i = 0; - const int size = input1_dims.sizes[3] * input1_dims.strides[3]; + const int size = MatchingFlatSize(input1_dims, input2_dims, output_dims); #ifdef USE_NEON const auto activation_min = vdupq_n_f32(output_activation_min); const auto activation_max = vdupq_n_f32(output_activation_max); @@ -3012,10 +3028,10 @@ void Mul(const int32* input1_data, const Dims<4>& input1_dims, auto output_map = MapAsVector(output_data, output_dims); if (AreSameDims(input1_dims, input2_dims)) { output_map.array() = input1_map.array() * input2_map.array(); - } else if (RequiredBufferSizeForDims(input2_dims) == 1) { + } else if (FlatSize(input2_dims) == 1) { auto scalar = input2_data[0]; output_map.array() = input1_map.array() * scalar; - } else if (RequiredBufferSizeForDims(input1_dims) == 1) { + } else if (FlatSize(input1_dims) == 1) { auto scalar = input1_data[0]; output_map.array() = scalar * input2_map.array(); } else { @@ -3031,9 +3047,7 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, // This is a copy of the reference implementation. We do not currently have a // properly optimized version. 
- const int flat_size = RequiredBufferSizeForDims(output_dims); - TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input1_dims), flat_size); - TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input2_dims), flat_size); + const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. @@ -3055,9 +3069,7 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, // properly optimized version. TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - const int flat_size = RequiredBufferSizeForDims(output_dims); - TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input1_dims), flat_size); - TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input2_dims), flat_size); + const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. @@ -3166,9 +3178,9 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, const int32 input2_val = input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; const int32 unclamped_result = - output_offset + - MultiplyByQuantizedMultiplierSmallerThanOne( - input1_val * input2_val, output_multiplier, output_shift); + output_offset + MultiplyByQuantizedMultiplierSmallerThanOneExp( + input1_val * input2_val, output_multiplier, + kReverseShift * output_shift); const int32 clamped_output = std::min(output_activation_max, std::max(output_activation_min, unclamped_result)); @@ -3200,26 +3212,11 @@ inline void Div(const float* input1_data, const Dims<4>& input1_dims, const float* input2_data, const Dims<4>& input2_dims, float output_activation_min, float output_activation_max, float* output_data, const Dims<4>& output_dims) { - const int batches = - MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); - const int height = - MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); - const int width = - 
MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); - const int depth = - MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - for (int c = 0; c < depth; ++c) { - output_data[Offset(output_dims, c, x, y, b)] = - ActivationFunctionWithMinMax( - input1_data[Offset(input1_dims, c, x, y, b)] / - input2_data[Offset(input2_dims, c, x, y, b)], - output_activation_min, output_activation_max); - } - } - } + const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); + for (int i = 0; i < flat_size; i++) { + output_data[i] = ActivationFunctionWithMinMax( + input1_data[i] / input2_data[i], output_activation_min, + output_activation_max); } } @@ -3273,26 +3270,12 @@ inline void Sub(const float* input1_data, const Dims<4>& input1_dims, const float* input2_data, const Dims<4>& input2_dims, float output_activation_min, float output_activation_max, float* output_data, const Dims<4>& output_dims) { - const int batches = - MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); - const int height = - MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); - const int width = - MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); - const int depth = - MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - for (int c = 0; c < depth; ++c) { - output_data[Offset(output_dims, c, x, y, b)] = - ActivationFunctionWithMinMax( - input1_data[Offset(input1_dims, c, x, y, b)] - - input2_data[Offset(input2_dims, c, x, y, b)], - output_activation_min, output_activation_max); - } - } - } + gemmlowp::ScopedProfilingLabel label("Sub"); + const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = 
ActivationFunctionWithMinMax( + input1_data[i] - input2_data[i], output_activation_min, + output_activation_max); } } @@ -3379,15 +3362,17 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sub = scaled_input1_val - scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sub, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sub, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -3601,15 +3586,9 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims, gemmlowp::ScopedProfilingLabel label( "LstmCell/quantized (8bit external, 16bit internal)"); // Gather dimensions information, and perform consistency checks. 
- const int batches = - MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, - output_state_dims, 3, output_activ_dims, 3); - const int height = - MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, - output_state_dims, 2, output_activ_dims, 2); - const int width = - MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, - output_state_dims, 1, output_activ_dims, 1); + const int outer_size = + MatchingFlatSizeSkipDim(input_dims, 0, prev_activ_dims, prev_state_dims, + output_state_dims, output_activ_dims); TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1); TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1); const int input_depth = ArraySize(input_dims, 0); @@ -3625,9 +3604,7 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims, MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0, output_state_dims, 0, output_activ_dims, 0); TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4); - const int fc_batches = ArraySize(activ_temp_dims, 1) * - ArraySize(activ_temp_dims, 2) * - ArraySize(activ_temp_dims, 3); + const int fc_batches = FlatSizeSkipDim(activ_temp_dims, 0); const int fc_output_depth = MatchingArraySize(weights_dims, 1, activ_temp_dims, 0); const int fc_accum_depth = ArraySize(weights_dims, 0); @@ -3683,7 +3660,6 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims, // Rest of the LSTM cell: tanh and logistic math functions, and some adds // and muls, all done in 16-bit fixed-point. 
- const int outer_size = batches * width * height; const int16* input_gate_input_ptr = activ_temp_data_int16; const int16* input_modulation_gate_input_ptr = activ_temp_data_int16 + output_depth; @@ -3849,20 +3825,15 @@ void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, gemmlowp::ScopedProfilingLabel label("TensorFlowSplit"); TFLITE_DCHECK_GE(outputs_count, 1); for (int i = 0; i < outputs_count; i++) { - /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3); - /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2); - /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1); + MatchingFlatSizeSkipDim(*output_dims[i], 0, input_dims); } - const int batches = MatchingArraySize(*output_dims[0], 3, input_dims, 3); - const int height = MatchingArraySize(*output_dims[0], 2, input_dims, 2); - const int width = MatchingArraySize(*output_dims[0], 1, input_dims, 1); + const int outer_size = FlatSizeSkipDim(input_dims, 0); TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); - // for now we dont have a model with a TensorFlowSplit + // For now we don't have a model with a TensorFlowSplit // with fused activation function. 
TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); - const int whb = width * height * batches; const Scalar* input_ptr = input_data; - for (int k = 0; k < whb; k++) { + for (int k = 0; k < outer_size; k++) { for (int i = 0; i < outputs_count; ++i) { memcpy(output_data[i] + k * output_dims[i]->sizes[0], input_ptr, output_dims[i]->sizes[0] * sizeof(Scalar)); @@ -4387,10 +4358,7 @@ inline void LocalResponseNormalization(const float* input_data, float* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("LocalResponseNormalization"); - /* const int batches = */ MatchingArraySize(input_dims, 3, output_dims, 3); - /* const int height = */ MatchingArraySize(input_dims, 2, output_dims, 2); - /* const int width = */ MatchingArraySize(input_dims, 1, output_dims, 1); - /* const int depth = */ MatchingArraySize(input_dims, 0, output_dims, 0); + MatchingFlatSize(input_dims, output_dims); const auto data_in = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); auto data_out = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); @@ -4433,10 +4401,7 @@ inline void Softmax(const float* input_data, const Dims<4>& input_dims, float beta, float* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("Softmax"); - /* const int batches = */ MatchingArraySize(input_dims, 3, output_dims, 3); - /* const int height = */ MatchingArraySize(input_dims, 2, output_dims, 2); - /* const int width = */ MatchingArraySize(input_dims, 1, output_dims, 1); - /* const int depth = */ MatchingArraySize(input_dims, 0, output_dims, 0); + MatchingFlatSize(input_dims, output_dims); const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); @@ -4468,13 +4433,9 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, using FixedPoint0 = gemmlowp::FixedPoint; gemmlowp::ScopedProfilingLabel label("Softmax/8bit"); - const int batches = 
MatchingArraySize(input_dims, 3, output_dims, 3); - const int height = MatchingArraySize(input_dims, 2, output_dims, 2); - const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int outer_size = batches * height * width; - for (int b = 0; b < outer_size; ++b) { const uint8* input_data_ptr = input_data + b * depth; uint8* output_data_ptr = output_data + b * depth; @@ -4666,39 +4627,147 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, float* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("LogSoftmax"); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int height = MatchingArraySize(input_dims, 2, output_dims, 2); - const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - // Find max element value which we'll use to ensure numerical stability - // taking advantage of the following equality: - // log(exp(x[i])/sum(exp(x[i]))) == log(exp(x[i]+C)/sum(exp(x[i]+C))) - float max = std::numeric_limits::lowest(); - for (int c = 0; c < depth; ++c) { - max = std::max(max, input_data[Offset(input_dims, c, x, y, b)]); - } + for (int i = 0; i < outer_size; ++i) { + const float* block_input_data = input_data + i * depth; + float* block_output_data = output_data + i * depth; + // Find max element value which we'll use to ensure numerical stability + // taking advantage of the following equality: + // log(exp(x[i])/sum(exp(x[i]))) == log(exp(x[i]+C)/sum(exp(x[i]+C))) + float max = 
std::numeric_limits::lowest(); + for (int c = 0; c < depth; ++c) { + max = std::max(max, block_input_data[c]); + } - // Compute sum. - float sum = 0.f; - for (int c = 0; c < depth; ++c) { - sum += std::exp(input_data[Offset(input_dims, c, x, y, b)] - max); - } + // Compute sum. + float sum = 0.f; + for (int c = 0; c < depth; ++c) { + sum += std::exp(block_input_data[c] - max); + } - // Compute result. - const float log_sum = std::log(sum); - for (int c = 0; c < depth; ++c) { - output_data[Offset(output_dims, c, x, y, b)] = - input_data[Offset(input_dims, c, x, y, b)] - max - log_sum; - } - } + // Compute result. + const float log_sum = std::log(sum); + for (int c = 0; c < depth; ++c) { + block_output_data[c] = block_input_data[c] - max - log_sum; } } } +template +inline gemmlowp::FixedPoint +log_x_for_x_greater_than_or_equal_to_1_impl( + gemmlowp::FixedPoint input_val) { + // assert(__builtin_clz(0u) >= std::numeric_limits::digits - 1); + // assert(__builtin_clz(0u) <= std::numeric_limits::digits); + using FixedPoint0 = gemmlowp::FixedPoint; + // The reason for accumulating the result with an extra bit of headroom is + // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled * + // recip_denom will otherwise introduce an error. 
+ static constexpr int kAccumIntegerBits = OutputIntegerBits + 1; + using FixedPointAccum = gemmlowp::FixedPoint; + + const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1488522236, std::log(2.0)); + const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5))); + const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1518500250, std::sqrt(0.5)); + const FixedPoint0 one_quarter = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0); + + const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1057819769, + 2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0))); + + const FixedPointAccum shifted_quarter = + gemmlowp::Rescale(one_quarter); + + // Reinterpret the input value as Q0.31, because we will figure out the + // required shift "ourselves" instead of using, say, Rescale. 
+ FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw()); + // z_a_pow_2 = input_integer_bits - z_a_headroom; + int z_a_headroom_plus_1 = __builtin_clz(static_cast(z_a.raw())); + FixedPoint0 r_a_tmp = + SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1)); + const int32 r_a_raw = + SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1); + // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25); + // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25, + // InputIntegerBits - z_b_headroom - 0.25); + const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp( + FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( + InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)), + shifted_quarter); + + // z_b is treated like z_a, but premultiplying by sqrt(0.5). + FixedPoint0 z_b = z_a * sqrt_half; + int z_b_headroom = __builtin_clz(static_cast(z_b.raw())) - 1; + const int32 r_b_raw = + SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom); + const FixedPointAccum z_b_pow_2_adj = SaturatingSub( + FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( + InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)), + shifted_quarter); + + const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw)); + const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw( + std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw())); + + const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half); + FixedPoint0 q = r - sqrt_sqrt_half; + q = q + q; + + const FixedPoint0 common_sq = q * q; + const FixedPoint0 num = q * r + q * common_sq * alpha_n; + const FixedPoint0 denom_minus_one_0 = + p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q; + const FixedPoint0 recip_denom = + one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0); + + const FixedPointAccum num_scaled = gemmlowp::Rescale(num); + return gemmlowp::Rescale(z_pow_2_adj * log_2 + + num_scaled * recip_denom); +} + +// Minimum output bits to 
accommodate log of maximum input range. It actually +// does not matter if one considers, say, [-64,64] or [-64,64). +// +// For example, run this through Octave: +// [0:127; ... +// ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ... +// ceil(log(abs( log(2.^(0:127))+1 ))/log(2))] +constexpr int min_log_x_output_bits(int input_bits) { + return input_bits > 90 + ? 7 + : input_bits > 44 + ? 6 + : input_bits > 21 + ? 5 + : input_bits > 10 + ? 4 + : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1; +} + +template +inline gemmlowp::FixedPoint +log_x_for_x_greater_than_or_equal_to_1( + gemmlowp::FixedPoint input_val) { + static_assert( + OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits), + "Output integer bits must be sufficent to accommodate logs of inputs."); + return log_x_for_x_greater_than_or_equal_to_1_impl( + input_val); +} + // Currently just a copy of the reference code. inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, int32 input_multiplier, int32 input_left_shift, @@ -4723,15 +4792,16 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); for (int i = 0; i < outer_size; ++i) { + const uint8* block_input_data = input_data + i * depth; + uint8* block_output_data = output_data + i * depth; uint8 max_in_row = 0; for (int c = 0; c < depth; ++c) { - max_in_row = std::max(max_in_row, input_data[i * depth + c]); + max_in_row = std::max(max_in_row, block_input_data[c]); } FixedPointAccum sum_of_exps = FixedPointAccum::Zero(); for (int c = 0; c < depth; ++c) { - int32 input_diff = - static_cast(input_data[i * depth + c]) - max_in_row; + int32 input_diff = static_cast(block_input_data[c]) - max_in_row; if (input_diff >= diff_min) { const int32 input_diff_rescaled = MultiplyByQuantizedMultiplierGreaterThanOne( @@ -4743,13 +4813,10 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, } } - // TODO(b/77858996): Implement 
fixed-point log(). - // Not a fully-quantized implementation: floating-point log(). - const float float_log_sum_of_exps = - std::log(static_cast(sum_of_exps.raw()) / - (1 << (31 - kAccumulationIntegerBits))); - const int32 fixed_log_sum_of_exps = static_cast(TfLiteRound( - float_log_sum_of_exps * (1 << (31 - kScaledDiffIntegerBits)))); + const int32 fixed_log_sum_of_exps = + log_x_for_x_greater_than_or_equal_to_1( + sum_of_exps) + .raw(); // rescaled_diff_min is smallest representable in // Q(kScaledDiffIntegerBits).(31-kScaledDiffIntegerBits) plus the @@ -4760,13 +4827,12 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, fixed_log_sum_of_exps + std::numeric_limits::lowest(); const int adjusted_diff_min = std::max(diff_min - 1, // Note use of > below instead of >= above. - MultiplyByQuantizedMultiplierSmallerThanOne( + MultiplyByQuantizedMultiplierSmallerThanOneExp( rescaled_diff_min, reverse_scaling_divisor, - reverse_scaling_right_shift)); + kReverseShift * reverse_scaling_right_shift)); for (int c = 0; c < depth; ++c) { - int32 input_diff = - static_cast(input_data[i * depth + c]) - max_in_row; + int32 input_diff = static_cast(block_input_data[c]) - max_in_row; if (input_diff > adjusted_diff_min) { const int32 input_diff_rescaled = MultiplyByQuantizedMultiplierGreaterThanOne( @@ -4777,11 +4843,11 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, 31 - kScaledDiffIntegerBits - kOutputIntegerBits) + 255; - output_data[i * depth + c] = static_cast( + block_output_data[c] = static_cast( std::max(std::min(unsat_output, static_cast(255)), 0)); } else { // Set output to smallest value. 
- output_data[i * depth + c] = 0; + block_output_data[c] = 0; } } } @@ -4801,11 +4867,7 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, int32 input_multiplier, int input_left_shift, uint8* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("Logistic/Uint8"); - /* batches */ MatchingArraySize(input_dims, 3, output_dims, 3); - /* height */ MatchingArraySize(input_dims, 2, output_dims, 2); - /* width */ MatchingArraySize(input_dims, 1, output_dims, 1); - /* depth */ MatchingArraySize(input_dims, 0, output_dims, 0); - const int size = RequiredBufferSizeForDims(input_dims); + const int size = MatchingFlatSize(input_dims, output_dims); int c = 0; #ifdef USE_NEON @@ -4940,8 +5002,7 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, inline void Logistic(const int16* input_data, const Dims<4>& input_dims, int16* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("Logistic/Int16"); - const int flat_size = RequiredBufferSizeForDims(output_dims); - TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input_dims), flat_size); + const int flat_size = MatchingFlatSize(output_dims, input_dims); for (int i = 0; i < flat_size; i++) { } @@ -5012,11 +5073,7 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, uint8* output_data, const Dims<4>& output_dims) { // Note that this is almost the exact same code as in Logistic(). 
gemmlowp::ScopedProfilingLabel label("Tanh"); - /* batches */ MatchingArraySize(input_dims, 3, output_dims, 3); - /* height */ MatchingArraySize(input_dims, 2, output_dims, 2); - /* width */ MatchingArraySize(input_dims, 1, output_dims, 1); - /* depth */ MatchingArraySize(input_dims, 0, output_dims, 0); - const int size = RequiredBufferSizeForDims(input_dims); + const int size = MatchingFlatSize(input_dims, output_dims); int c = 0; int32_t output_zero_point = 128; @@ -5166,8 +5223,7 @@ inline void Tanh(const int16* input_data, const Dims<4>& input_dims, TFLITE_DCHECK_GE(input_left_shift, 0); TFLITE_DCHECK_LE(input_left_shift, 1); - const int flat_size = RequiredBufferSizeForDims(output_dims); - TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input_dims), flat_size); + const int flat_size = MatchingFlatSize(output_dims, input_dims); int c = 0; const int16* input_data_ptr = input_data; @@ -5262,20 +5318,11 @@ inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, int32 zero_point, double scale, float* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("Dequantize"); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int height = MatchingArraySize(input_dims, 2, output_dims, 2); - const int width = MatchingArraySize(input_dims, 1, output_dims, 1); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - for (int c = 0; c < depth; ++c) { - int32 val = input_data[Offset(input_dims, c, x, y, b)]; - float result = static_cast(scale * (val - zero_point)); - output_data[Offset(output_dims, c, x, y, b)] = result; - } - } - } + const int flat_size = MatchingFlatSize(output_dims, input_dims); + for (int i = 0; i < flat_size; ++i) { + int32 val = input_data[i]; + float result = static_cast(scale * (val - zero_point)); + output_data[i] = result; } } @@ -5298,25 +5345,15 @@ inline 
void FakeQuant(const float* input_data, const Dims<4>& input_dims, &nudged_max, &nudged_scale); const float inv_nudged_scale = 1.0f / nudged_scale; - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int height = MatchingArraySize(input_dims, 2, output_dims, 2); - const int width = MatchingArraySize(input_dims, 1, output_dims, 1); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - for (int c = 0; c < depth; ++c) { - const float src_val = input_data[Offset(input_dims, c, x, y, b)]; - const float clamped = - std::min(nudged_max, std::max(nudged_min, src_val)); - const float clamped_shifted = clamped - nudged_min; - const float dst_val = - TfLiteRound(clamped_shifted * inv_nudged_scale) * nudged_scale + - nudged_min; - output_data[Offset(output_dims, c, x, y, b)] = dst_val; - } - } - } + const int flat_size = MatchingFlatSize(output_dims, input_dims); + for (int i = 0; i < flat_size; ++i) { + const float src_val = input_data[i]; + const float clamped = std::min(nudged_max, std::max(nudged_min, src_val)); + const float clamped_shifted = clamped - nudged_min; + const float dst_val = + TfLiteRound(clamped_shifted * inv_nudged_scale) * nudged_scale + + nudged_min; + output_data[i] = dst_val; } } @@ -5685,6 +5722,46 @@ inline void ResizeBilinearGeneric(const float* input_data, } } +template +inline void ResizeBilinearGenericSmallChannel( + const T* input_data, const Dims<4>& input_dims, T* output_data, + const Dims<4>& output_dims, int32 batches, int32 input_height, + int32 input_width, int32 depth, int32 output_height, int32 output_width, + float height_scale, float width_scale) { + memset(output_data, 0, + batches * output_height * output_width * depth * sizeof(T)); + + T* output_ptr = &output_data[0]; + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + float input_y = y * 
height_scale; + int32 y0 = static_cast(std::floor(input_y)); + int32 y1 = std::min(y0 + 1, input_height - 1); + for (int x = 0; x < output_width; ++x) { + float input_x = x * width_scale; + int32 x0 = static_cast(input_x); + int32 x1 = std::min(x0 + 1, input_width - 1); + + int32 input_offset[4] = { + Offset(input_dims, 0, x0, y0, b), Offset(input_dims, 0, x1, y0, b), + Offset(input_dims, 0, x0, y1, b), Offset(input_dims, 0, x1, y1, b)}; + float scale[4] = {(1 - (input_y - y0)) * (1 - (input_x - x0)), + (1 - (input_y - y0)) * (input_x - x0), + (input_y - y0) * (1 - (input_x - x0)), + (input_y - y0) * (input_x - x0)}; + + for (int d = 0; d < depth; d++) { + const T* input_ptr = &input_data[d]; + *output_ptr++ = static_cast(input_ptr[input_offset[0]] * scale[0] + + input_ptr[input_offset[1]] * scale[1] + + input_ptr[input_offset[2]] * scale[2] + + input_ptr[input_offset[3]] * scale[3]); + } + } + } + } +} + inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, const Dims<4>& output_size_dims, float* output_data, @@ -5725,6 +5802,41 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, } } +// TODO(prabhumk): This is not a real quantized bilinear. It does not use int8 +// or int16 arithmetic. 
+inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, uint8* output_data, + const Dims<4>& output_dims, bool align_corners) { + gemmlowp::ScopedProfilingLabel label("ResizeBilinear"); + int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); + int32 input_height = ArraySize(input_dims, 2); + int32 input_width = ArraySize(input_dims, 1); + int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2); + int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)]; + int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; + + float height_scale = + (align_corners && output_height > 1) + ? (static_cast(input_height - 1) / (output_height - 1)) + : (static_cast(input_height) / output_height); + + float width_scale = + (align_corners && output_width > 1) + ? 
(static_cast(input_width - 1) / (output_width - 1)) + : (static_cast(input_width) / output_width); + + ResizeBilinearGenericSmallChannel( + input_data, input_dims, output_data, output_dims, batches, input_height, + input_width, depth, output_height, output_width, height_scale, + width_scale); +} + // legacy, for compatibility with old checked-in code inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, @@ -5734,6 +5846,15 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, output_data, output_dims, /*align_corners=*/false); } +// legacy, for compatibility with old checked-in code +inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, uint8* output_data, + const Dims<4>& output_dims) { + ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims, + output_data, output_dims, /*align_corners=*/false); +} + template inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, const int32* block_shape_data, @@ -6045,10 +6166,10 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims, size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3]; const int start_h = begin[2]; const int stop_h = - size[2] == -1 ? input_dims.sizes[2] - start_b : start_b + size[2]; + size[2] == -1 ? input_dims.sizes[2] - start_h : start_h + size[2]; const int start_w = begin[1]; const int stop_w = - size[1] == -1 ? input_dims.sizes[1] - start_b : start_b + size[1]; + size[1] == -1 ? input_dims.sizes[1] - start_w : start_w + size[1]; const int start_d = begin[0]; const int stop_d = size[0] == -1 ? 
input_dims.sizes[0] - start_d : start_d + size[0]; @@ -6147,10 +6268,10 @@ void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, auto output_map = MapAsVector(output_data, output_dims); if (AreSameDims(input1_dims, input2_dims)) { output_map.array() = input1_map.array() - input2_map.array(); - } else if (RequiredBufferSizeForDims(input1_dims) == 1) { + } else if (FlatSize(input1_dims) == 1) { auto scalar = input1_data[0]; output_map.array() = scalar - input2_map.array(); - } else if (RequiredBufferSizeForDims(input2_dims) == 1) { + } else if (FlatSize(input2_dims) == 1) { auto scalar = input2_data[0]; output_map.array() = input1_map.array() - scalar; } else { @@ -6188,32 +6309,28 @@ void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, // The current ArgMax implemention can only determine the index of the maximum // value in the last dimension. So the axis argument is ignored. - TFLITE_DCHECK_EQ(axis[0], 3); // For ArgMax, the number of output dimensions = (number of input dimensions - // 1). For the sake of simplicity, the output dimensions are equal to the // input dimensions here. We enforce the constraint that the last dimension // must always be 1. 
TFLITE_DCHECK_EQ(ArraySize(output_dims, 0), 1); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int height = MatchingArraySize(input_dims, 2, output_dims, 2); - const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); const int depth = ArraySize(input_dims, 0); - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - auto max_value = input_data[Offset(input_dims, 0, x, y, b)]; - int max_index = 0; - for (int d = 1; d < depth; ++d) { - const auto& curr_value = input_data[Offset(input_dims, d, x, y, b)]; - if (curr_value > max_value) { - max_value = curr_value; - max_index = d; - } - } - output_data[Offset(output_dims, 0, x, y, b)] = max_index; + for (int i = 0; i < outer_size; ++i) { + auto max_value = *input_data; + ++input_data; + int max_index = 0; + for (int d = 1; d < depth; ++d) { + const auto& curr_value = *input_data; + if (curr_value > max_value) { + max_value = curr_value; + max_index = d; } + ++input_data; } + *output_data = max_index; + ++output_data; } } @@ -6256,8 +6373,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, // To optimize, start by using the conv code with transposed weights for the // case of stride_height = stride_width = 1. 
const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3); - const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); const int input_height = ArraySize(input_dims, 2); const int input_width = ArraySize(input_dims, 1); const int filter_height = ArraySize(filter_dims, 2); @@ -6293,7 +6410,7 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, const int out_y_origin = (in_y * stride_height) - pad_height; for (int filter_y = 0; filter_y < filter_height; ++filter_y) { for (int filter_x = 0; filter_x < filter_width; ++filter_x) { - for (int out_channel = 0; out_channel < input_depth; + for (int out_channel = 0; out_channel < output_depth; ++out_channel) { // Compute output element location const int out_x = out_x_origin + filter_x; @@ -6304,8 +6421,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, float input_value = input_data[Offset(input_dims, in_channel, in_x, in_y, batch)]; float filter_value = - filter_data[Offset(filter_dims, out_channel, filter_x, - filter_y, in_channel)]; + filter_data[Offset(filter_dims, in_channel, filter_x, + filter_y, out_channel)]; output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] += input_value * filter_value; } @@ -6318,59 +6435,6 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, } } -// UNOPTIMIZED COPY of Select from reference_ops.h. 
-template -inline void Select(const D* input_condition_data, - const Dims<4>& input_condition_dims, const T* input_x_data, - const Dims<4>& input_x_dims, const T* input_y_data, - const Dims<4>& input_y_dims, T* output_data, - const Dims<4>& output_dims) { - const int64_t batches = - MatchingArraySize(input_condition_dims, 3, input_x_dims, 3, input_y_dims, - 3, output_dims, 3); - const int64_t height = - MatchingArraySize(input_condition_dims, 2, input_x_dims, 2, input_y_dims, - 2, output_dims, 2); - const int64_t width = MatchingArraySize(input_condition_dims, 1, input_x_dims, - 1, input_y_dims, 1, output_dims, 1); - const int64_t depth = MatchingArraySize(input_condition_dims, 0, input_x_dims, - 0, input_y_dims, 0, output_dims, 0); - - const int64_t num_elements = batches * height * width * depth; - for (int64_t i = 0; i < num_elements; ++i) { - output_data[i] = - input_condition_data[i] ? input_x_data[i] : input_y_data[i]; - } -} - -// UNOPTIMIZED COPY of RankOneSelect from reference_ops.h. -template -inline void RankOneSelect(const D* input_condition_data, - const Dims<4>& input_condition_dims, - const T* input_x_data, const Dims<4>& input_x_dims, - const T* input_y_data, const Dims<4>& input_y_dims, - T* output_data, const Dims<4>& output_dims) { - const int64_t rank = ArraySize(input_condition_dims, 0); - - const int64_t batches = - MatchingArraySize(input_x_dims, 3, input_y_dims, 3, output_dims, 3); - const int64_t height = - MatchingArraySize(input_x_dims, 2, input_y_dims, 2, output_dims, 2); - const int64_t width = - MatchingArraySize(input_x_dims, 1, input_y_dims, 1, output_dims, 1); - const int64_t depth = - MatchingArraySize(input_x_dims, 0, input_y_dims, 0, output_dims, 0); - - TFLITE_DCHECK_EQ(rank, batches); - - int64_t offset = 0; - int64_t size = depth * height * width; - for (int64_t i = 0; i < rank; i++) { - const T* input_data = input_condition_data[i] ? 
input_x_data : input_y_data; - memcpy(output_data + offset, input_data + offset, size * sizeof(T)); - } -} - } // namespace optimized_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h index d570dadd86b4dc7c3abe341a4955320367330b9c..f14667090f5c3867c7992211272063239f3b92aa 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h @@ -127,6 +127,10 @@ void PortableZeroVector(float* vector, int v_size); // Limit a float input f between +abs_limit and -abs_limit. float PortableClip(float f, float abs_limit); +// Check if all entries of a vector are zero. +bool PortableIsZeroVector(const float* vector, int v_size); +bool NeonIsZeroVector(const float* vector, int v_size); + // Symmetric quantizer. void PortableSymmetricQuantizeFloats(const float* values, const int size, int8_t* quantized_values, float* min, diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h index e9b6baeaee87d22aef238410bc9f447509a81c47..d57739279f44ad2a9fff2bd4dca21047b8147f2a 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h @@ -76,8 +76,8 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, if (bias_data) { acc += bias_data[Offset(bias_dims, oc, 0, 0, 0)]; } - acc = MultiplyByQuantizedMultiplierSmallerThanOne( - acc, output_multiplier, output_shift); + acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, + -output_shift); acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); diff --git 
a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc index 2607adc0c18aeaa8dc2061e0e95a307205700a08..f8c6f341f7e61529bbbac592f9caf115f6121e0c 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -29,9 +29,18 @@ float PortableClip(float f, float abs_limit) { return result; } +bool PortableIsZeroVector(const float* vector, int v_size) { + for (int i = 0; i < v_size; ++i) { + if (*vector++ != 0.0f) return false; + } + return true; +} + void PortableSymmetricQuantizeFloats(const float* values, const int size, - int8_t* quantized_values, float* min, - float* max, float* scaling_factor) { + int8_t* quantized_values, + float* __restrict__ min, + float* __restrict__ max, + float* __restrict__ scaling_factor) { auto minmax = std::minmax_element(values, values + size); *min = *minmax.first; *max = *minmax.second; @@ -71,13 +80,14 @@ void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix, void PortableMatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, - const int8_t* __restrict__ vectors, const float* scaling_factors, - int n_batch, float* __restrict__ result, int result_stride) { + const int8_t* __restrict__ vectors, + const float* __restrict__ scaling_factors, int n_batch, + float* __restrict__ result, int result_stride) { int batch, row, col; for (batch = 0; batch < n_batch; ++batch, vectors += m_cols) { const float batch_scaling_factor_inv = 1.0 / scaling_factors[batch]; // Get the address of the first row. - int8_t* row_ptr = (int8_t*)matrix; // NOLINT + const int8_t* row_ptr = matrix; for (row = 0; row < m_rows; ++row, result += result_stride) { // Initialize the dot product sum for the row to 0. 
int32_t dotprod = 0; diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h index 1757a9f5e5299401fb9fef1d38870bd4e63ba3c1..d2e1fecd25cf3d11d3daffcc566dc1d5df97128c 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h @@ -25,6 +25,8 @@ namespace tensor_utils { // Limit a float input f between +abs_limit and -abs_limit. float PortableClip(float f, float abs_limit); +bool PortableIsZeroVector(const float* vector, int v_size); + void PortableSymmetricQuantizeFloats(const float* values, const int size, int8_t* quantized_values, float* min, float* max, float* scaling_factor); @@ -112,6 +114,10 @@ void PortableReductionSumVector(const float* input_vector, float* output_vector, float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); } +bool IsZeroVector(const float* vector, int v_size) { + return PortableIsZeroVector(vector, v_size); +} + void SymmetricQuantizeFloats(const float* values, const int size, int8_t* quantized_values, float* min, float* max, float* scaling_factor) { diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index e2978cfd6778a47286728f8e1f1eecda6bfcb024..f8ee554894c63bf1e551aac4b14b9a4f27039e16 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -33,8 +33,131 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/internal/types.h" namespace tflite { + +// TODO(b/77858996): Add these to gemmlowp. 
+template +IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) { + static_assert(std::is_same::value, "unimplemented"); + return a; +} + +template <> +inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) { + std::int64_t a64 = a; + std::int64_t b64 = b; + std::int64_t sum = a64 + b64; + return static_cast(std::min( + static_cast(std::numeric_limits::max()), + std::max( + static_cast(std::numeric_limits::min()), + sum))); +} + +template +gemmlowp::FixedPoint SaturatingAddNonGemmlowp( + gemmlowp::FixedPoint a, + gemmlowp::FixedPoint b) { + return gemmlowp::FixedPoint::FromRaw( + SaturatingAddNonGemmlowp(a.raw(), b.raw())); +} + +template +IntegerType SaturatingSub(IntegerType a, IntegerType b) { + static_assert(std::is_same::value, "unimplemented"); + return a; +} + +template <> +inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) { + std::int32_t a32 = a; + std::int32_t b32 = b; + std::int32_t diff = a32 - b32; + return static_cast(std::min(32767, std::max(-32768, diff))); +} + +template <> +inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) { + std::int64_t a64 = a; + std::int64_t b64 = b; + std::int64_t diff = a64 - b64; + return static_cast(std::min( + static_cast(std::numeric_limits::max()), + std::max( + static_cast(std::numeric_limits::min()), + diff))); +} + +template +gemmlowp::FixedPoint SaturatingSub( + gemmlowp::FixedPoint a, + gemmlowp::FixedPoint b) { + return gemmlowp::FixedPoint::FromRaw( + SaturatingSub(a.raw(), b.raw())); +} +// End section to be moved to gemmlowp. + namespace reference_ops { +// TODO(b/80247582) Remove this constant. +// This will be phased out as the shifts are revised with more thought. Use of a +// constant enables us to track progress on this work. +// +// Used mainly to convert from old-style shifts (right) to new-style (left). 
+static constexpr int kReverseShift = -1; + +template +int CountLeadingZeros(T integer_input) { + static_assert(std::is_unsigned::value, + "Only unsigned integer types handled."); + if (integer_input == 0) { + return std::numeric_limits::digits; + } + const T one_in_leading_positive = static_cast(1) + << (std::numeric_limits::digits - 1); + int leading_zeros = 0; + while (integer_input < one_in_leading_positive) { + integer_input <<= 1; + ++leading_zeros; + } + return leading_zeros; +} + +template +IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) { + if (exponent == 0) { + return x; + } + using ScalarIntegerType = + typename gemmlowp::FixedPointRawTypeTraits::ScalarRawType; + const IntegerType min = + gemmlowp::Dup(std::numeric_limits::min()); + const IntegerType max = + gemmlowp::Dup(std::numeric_limits::max()); + const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); + + const std::int32_t threshold = + ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1); + const IntegerType positive_mask = + gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup(threshold)); + const IntegerType negative_mask = + gemmlowp::MaskIfLessThan(x, gemmlowp::Dup(-threshold)); + + IntegerType result = gemmlowp::ShiftLeft(x, exponent); + result = gemmlowp::SelectUsingMask(positive_mask, max, result); + result = gemmlowp::SelectUsingMask(negative_mask, min, result); + return result; +} + +// If we want to leave IntegerBits fixed, then multiplication +// by a power of two has to be saturating/rounding, not exact anymore. +template +gemmlowp::FixedPoint +SaturatingRoundingMultiplyByPOTParam( + gemmlowp::FixedPoint a, int exponent) { + return gemmlowp::FixedPoint::FromRaw( + SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent)); +} + // DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE // BROADCASTING. 
// @@ -291,8 +414,8 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, if (bias_data) { acc += bias_data[Offset(bias_dims, out_channel, 0, 0, 0)]; } - acc = MultiplyByQuantizedMultiplierSmallerThanOne( - acc, output_multiplier, output_shift); + acc = MultiplyByQuantizedMultiplierSmallerThanOneExp( + acc, output_multiplier, kReverseShift * output_shift); acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); @@ -515,8 +638,8 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, if (bias_data) { acc += bias_data[Offset(bias_dims, out_c, 0, 0, 0)]; } - acc = MultiplyByQuantizedMultiplierSmallerThanOne(acc, output_multiplier, - output_shift); + acc = MultiplyByQuantizedMultiplierSmallerThanOneExp( + acc, output_multiplier, kReverseShift * output_shift); acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); @@ -893,31 +1016,30 @@ inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, int32 input_zero_point, uint8* output_data, const Dims<4>& output_dims) { - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int height = MatchingArraySize(input_dims, 2, output_dims, 2); - const int width = MatchingArraySize(input_dims, 1, output_dims, 1); const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - TFLITE_DCHECK_EQ(batches, 1); - TFLITE_DCHECK_EQ(height, 1); - TFLITE_DCHECK_EQ(width, 1); - int32 square_l2_norm = 0; - for (int i = 0; i < depth; i++) { - int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point; - square_l2_norm += diff * diff; - } - int32 inv_l2norm_multiplier; - int inv_l2norm_shift; - GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier, - &inv_l2norm_shift); - - for (int i = 0; i < depth; i++) { - int32 diff = 
input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point; - int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne( - 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); - int32 unclamped_output_val = 128 + rescaled_diff; - int32 output_val = std::min(255, std::max(0, unclamped_output_val)); - output_data[Offset(output_dims, i, 0, 0, 0)] = - static_cast(output_val); + const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); + for (int i = 0; i < outer_size; ++i) { + int32 square_l2_norm = 0; + for (int c = 0; c < depth; c++) { + int32 diff = + input_data[Offset(input_dims, c, i, 0, 0)] - input_zero_point; + square_l2_norm += diff * diff; + } + int32 inv_l2norm_multiplier; + int inv_l2norm_shift; + GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier, + &inv_l2norm_shift); + + for (int c = 0; c < depth; c++) { + int32 diff = + input_data[Offset(input_dims, c, i, 0, 0)] - input_zero_point; + int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp( + 128 * diff, inv_l2norm_multiplier, kReverseShift * inv_l2norm_shift); + int32 unclamped_output_val = 128 + rescaled_diff; + int32 output_val = std::min(255, std::max(0, unclamped_output_val)); + output_data[Offset(output_dims, c, i, 0, 0)] = + static_cast(output_val); + } } } @@ -983,15 +1105,17 @@ inline void Add(int left_shift, const uint8* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, 
input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -1021,9 +1145,7 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims, TFLITE_DCHECK_EQ(output_activation_max, 32767); } - const int flat_size = RequiredBufferSizeForDims(output_dims); - TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input1_dims), flat_size); - TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input2_dims), flat_size); + const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); TFLITE_DCHECK(input1_shift == 0 || input2_shift == 0); TFLITE_DCHECK_GE(input1_shift, 0); @@ -1139,15 +1261,17 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 
clamped_output = std::min(output_activation_max, @@ -1192,15 +1316,17 @@ inline void BroadcastAddFivefold( const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -1380,9 +1506,9 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, const int32 input2_val = input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; const int32 unclamped_result = - output_offset + - MultiplyByQuantizedMultiplierSmallerThanOne( - input1_val * input2_val, output_multiplier, output_shift); + output_offset + MultiplyByQuantizedMultiplierSmallerThanOneExp( + input1_val * input2_val, output_multiplier, + kReverseShift * output_shift); const int32 clamped_output = std::min(output_activation_max, std::max(output_activation_min, unclamped_result)); @@ -1399,9 +1525,7 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, int16* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("Mul/Int16"); - const int flat_size = RequiredBufferSizeForDims(output_dims); - 
TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input1_dims), flat_size); - TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input2_dims), flat_size); + const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. @@ -1421,9 +1545,7 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, gemmlowp::ScopedProfilingLabel label("Mul/Int16Uint8"); TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - const int flat_size = RequiredBufferSizeForDims(output_dims); - TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input1_dims), flat_size); - TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input2_dims), flat_size); + const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. @@ -1456,33 +1578,6 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, output_data, output_dims); } -inline void Div(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { - const int batches = - MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); - const int height = - MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); - const int width = - MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); - const int depth = - MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - for (int c = 0; c < depth; ++c) { - output_data[Offset(output_dims, c, x, y, b)] = - ActivationFunctionWithMinMax( - input1_data[Offset(input1_dims, c, x, y, b)] / - input2_data[Offset(input2_dims, c, x, y, b)], - output_activation_min, output_activation_max); - } - } - } - } 
-} - // TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then @@ -1524,6 +1619,18 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims, } } +inline void Div(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax( + input1_data[i] / input2_data[i], output_activation_min, + output_activation_max); + } +} + inline void Sub(const float* input1_data, const Dims<4>& input1_dims, const float* input2_data, const Dims<4>& input2_dims, float output_activation_min, float output_activation_max, @@ -1615,15 +1722,17 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sub = scaled_input1_val - scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sub, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sub, 
output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -1818,7 +1927,7 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims, // The quantization of the input, output arrays is as follows: // - The input activations are quantized as uint8 on the interval // [-1, 127/128]. -// The rationale for that is that that is the natural interval for output +// The rationale for that is that is the natural interval for output // activations (see next point) and these need to be concatenated together. // We could accommodate different ranges by re-scaling, but we empirically // found that setting the input activations range to be [-1, 127/128] in the @@ -1883,7 +1992,7 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims, // However, for a fixed-point implementation in 16-bit integers, using 5 // integer bits to represent the [-16, 16] range would leave only 11 // fractional bits, giving an increment of 2^-11 = 4.9e-4 between consecutive -// representable values. Notice that that is higher than the +// representable values. Notice that is higher than the // worst-case clamping error with clamping to [-8, 8]: 3.4e-4 for Logistic. // Using [-8, 8] thus seems like the better compromise overall, enjoying // an increment of 2.4e-4 between representable values and a worst-case @@ -2664,6 +2773,121 @@ inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, } } +// Although currently the name of this function says that it cannot handle +// values less than 1, in practice it can handle as low as 1/x_max, where +// x_max is the largest representable input. In other words, the output range +// is symmetric. 
+template +inline gemmlowp::FixedPoint +log_x_for_x_greater_than_or_equal_to_1_impl( + gemmlowp::FixedPoint input_val) { + using FixedPoint0 = gemmlowp::FixedPoint; + // The reason for accumulating the result with an extra bit of headroom is + // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled * + // recip_denom will otherwise introduce an error. + static constexpr int kAccumIntegerBits = OutputIntegerBits + 1; + using FixedPointAccum = gemmlowp::FixedPoint; + + const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1488522236, std::log(2.0)); + const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5))); + const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1518500250, std::sqrt(0.5)); + const FixedPoint0 one_quarter = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0); + + const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1057819769, + 2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0))); + + const FixedPointAccum shifted_quarter = + gemmlowp::Rescale(one_quarter); + + // Reinterpret the input value as Q0.31, because we will figure out the + // required shift "ourselves" instead of using, say, Rescale. 
+ FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw()); + // z_a_pow_2 = input_integer_bits - z_a_headroom; + int z_a_headroom_plus_1 = CountLeadingZeros(static_cast(z_a.raw())); + FixedPoint0 r_a_tmp = + SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1)); + const int32 r_a_raw = + SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1); + // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25); + // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25, + // InputIntegerBits - z_b_headroom - 0.25); + const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp( + FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( + InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)), + shifted_quarter); + + // z_b is treated like z_a, but premultiplying by sqrt(0.5). + FixedPoint0 z_b = z_a * sqrt_half; + int z_b_headroom = CountLeadingZeros(static_cast(z_b.raw())) - 1; + const int32 r_b_raw = + SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom); + const FixedPointAccum z_b_pow_2_adj = SaturatingSub( + FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( + InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)), + shifted_quarter); + + const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw)); + const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw( + std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw())); + + const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half); + FixedPoint0 q = r - sqrt_sqrt_half; + q = q + q; + + const FixedPoint0 common_sq = q * q; + const FixedPoint0 num = q * r + q * common_sq * alpha_n; + const FixedPoint0 denom_minus_one_0 = + p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q; + const FixedPoint0 recip_denom = + one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0); + + const FixedPointAccum num_scaled = gemmlowp::Rescale(num); + return gemmlowp::Rescale(z_pow_2_adj * log_2 + + num_scaled * recip_denom); +} + +// Minimum output 
bits to accommodate log of maximum input range. It actually +// does not matter if one considers, say, [-64,64] or [-64,64). +// +// For example, run this through Octave: +// [0:127; ... +// ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ... +// ceil(log(abs( log(2.^(0:127))+1 ))/log(2))] +constexpr int min_log_x_output_bits(int input_bits) { + return input_bits > 90 + ? 7 + : input_bits > 44 + ? 6 + : input_bits > 21 + ? 5 + : input_bits > 10 + ? 4 + : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1; +} + +template +inline gemmlowp::FixedPoint +log_x_for_x_greater_than_or_equal_to_1( + gemmlowp::FixedPoint input_val) { + static_assert( + OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits), + "Output integer bits must be sufficent to accommodate logs of inputs."); + return log_x_for_x_greater_than_or_equal_to_1_impl( + input_val); +} + inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, int32 input_multiplier, int32 input_left_shift, int32 reverse_scaling_divisor, @@ -2706,13 +2930,10 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, } } - // TODO(b/77858996): Implement fixed-point log(). - // Not a fully-quantized implementation: floating-point log(). - const float float_log_sum_of_exps = - std::log(static_cast(sum_of_exps.raw()) / - (1 << (31 - kAccumulationIntegerBits))); - const int32 fixed_log_sum_of_exps = static_cast(TfLiteRound( - float_log_sum_of_exps * (1 << (31 - kScaledDiffIntegerBits)))); + const int32 fixed_log_sum_of_exps = + log_x_for_x_greater_than_or_equal_to_1( + sum_of_exps) + .raw(); // rescaled_diff_min is smallest representable in // Q(kScaledDiffIntegerBits).(31-kScaledDiffIntegerBits) plus the @@ -2723,9 +2944,9 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, fixed_log_sum_of_exps + std::numeric_limits::lowest(); const int adjusted_diff_min = std::max(diff_min - 1, // Note use of > below instead of >= above. 
- MultiplyByQuantizedMultiplierSmallerThanOne( + MultiplyByQuantizedMultiplierSmallerThanOneExp( rescaled_diff_min, reverse_scaling_divisor, - reverse_scaling_right_shift)); + kReverseShift * reverse_scaling_right_shift)); for (int c = 0; c < depth; ++c) { int32 input_diff = @@ -2981,9 +3202,10 @@ inline void Gather(const T* input_data, const Dims<4>& input_dims, } } -inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, +template +inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims, const int32* output_size_data, - const Dims<4>& output_size_dims, float* output_data, + const Dims<4>& output_size_dims, T* output_data, const Dims<4>& output_dims, bool align_corners) { int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); int32 input_height = ArraySize(input_dims, 2); @@ -3015,15 +3237,15 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, int32 x0 = static_cast(std::floor(input_x)); int32 x1 = std::min(x0 + 1, input_width - 1); for (int c = 0; c < depth; ++c) { - float interpolation = input_data[Offset(input_dims, c, x0, y0, b)] * - (1 - (input_y - y0)) * - (1 - (input_x - x0)) + - input_data[Offset(input_dims, c, x0, y1, b)] * - (input_y - y0) * (1 - (input_x - x0)) + - input_data[Offset(input_dims, c, x1, y0, b)] * - (1 - (input_y - y0)) * (input_x - x0) + - input_data[Offset(input_dims, c, x1, y1, b)] * - (input_y - y0) * (input_x - x0); + T interpolation = + static_cast(input_data[Offset(input_dims, c, x0, y0, b)] * + (1 - (input_y - y0)) * (1 - (input_x - x0)) + + input_data[Offset(input_dims, c, x0, y1, b)] * + (input_y - y0) * (1 - (input_x - x0)) + + input_data[Offset(input_dims, c, x1, y0, b)] * + (1 - (input_y - y0)) * (input_x - x0) + + input_data[Offset(input_dims, c, x1, y1, b)] * + (input_y - y0) * (input_x - x0)); output_data[Offset(output_dims, c, x, y, b)] = interpolation; } } @@ -3036,8 +3258,18 @@ inline void ResizeBilinear(const float* input_data, const 
Dims<4>& input_dims, const int32* output_size_data, const Dims<4>& output_size_dims, float* output_data, const Dims<4>& output_dims) { - ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims, - output_data, output_dims, /*align_corners=*/false); + ResizeBilinear(input_data, input_dims, output_size_data, + output_size_dims, output_data, output_dims, + /*align_corners=*/false); +} + +inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, uint8* output_data, + const Dims<4>& output_dims) { + ResizeBilinear(input_data, input_dims, output_size_data, + output_size_dims, output_data, output_dims, + /*align_corners=*/false); } template @@ -3256,10 +3488,10 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims, size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3]; const int start_h = begin[2]; const int stop_h = - size[2] == -1 ? input_dims.sizes[2] - start_b : start_b + size[2]; + size[2] == -1 ? input_dims.sizes[2] - start_h : start_h + size[2]; const int start_w = begin[1]; const int stop_w = - size[1] == -1 ? input_dims.sizes[1] - start_b : start_b + size[1]; + size[1] == -1 ? input_dims.sizes[1] - start_w : start_w + size[1]; const int start_d = begin[0]; const int stop_d = size[0] == -1 ? input_dims.sizes[0] - start_d : start_d + size[0]; @@ -3284,63 +3516,124 @@ inline void Exp(const T* input_data, const size_t num_elements, } } -template -inline bool Mean(T* input_data, const int* input_dims, const int input_num_dims, - T* output_data, const int* output_dims, - const int output_num_dims, const int* axis, - const int num_axis_dimensions, bool keep_dims, int* temp_index, - int* resolved_axis, U* temp_sum) { - // resets output data. 
- size_t num_outputs = 1; - for (int idx = 0; idx < output_num_dims; ++idx) { - num_outputs *= static_cast(output_dims[idx]); - } - for (size_t idx = 0; idx < num_outputs; ++idx) { - output_data[idx] = T(); - temp_sum[idx] = U(); - } - // resets temp index. +// A generic reduce method that can be used for reduce_sum, reduce_mean, etc. +// It takes a reducer function as input and returns false when numeric overflow +// is detected. +// This method iterates through input data and reduce elements along the +// dimensions given in axis. +template +inline bool Reduce(const In* input_data, const int* input_dims, + const int* output_dims, const int input_num_dims, + const int output_num_dims, const int* axis, + const int num_axis, int* input_iter, + Out reducer(Out current, const In in, bool* overflow), + Out* output_data) { + // Reset input iterator. + TFLITE_DCHECK(input_num_dims > 0); for (int idx = 0; idx < input_num_dims; ++idx) { - temp_index[idx] = 0; + input_iter[idx] = 0; } - // resolves axis. - int num_resolved_axis = 0; - for (int idx = 0; idx < num_axis_dimensions; ++idx) { - int current = axis[idx]; - TFLITE_DCHECK(current < input_num_dims && current + input_num_dims >= 0); - if (current < 0) { - current += input_num_dims; - } + // Iterate through input_data. + do { + size_t input_offset = + ReducedOutputOffset(input_num_dims, input_dims, input_iter, 0, nullptr); + size_t output_offset = ReducedOutputOffset(input_num_dims, input_dims, + input_iter, num_axis, axis); + bool overflow = false; + output_data[output_offset] = reducer(output_data[output_offset], + input_data[input_offset], &overflow); + if (overflow) return false; + } while (NextIndex(input_num_dims, input_dims, input_iter)); + return true; +} + +inline bool ResolveAxis(const int num_dims, const int* axis, const int num_axis, + int* out_axis, int* out_num_axis) { + *out_num_axis = 0; // Just in case. 
+ // o(n^2) is fine since out_num_axis should be really small, mostly <= 4 + for (int idx = 0; idx < num_axis; ++idx) { + // Handle negative index. + int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx]; + TFLITE_DCHECK(current >= 0 && current < num_dims); bool is_dup = false; - for (int j = 0; j < num_resolved_axis; ++j) { - if (resolved_axis[j] == current) { + for (int j = 0; j < *out_num_axis; ++j) { + if (out_axis[j] == current) { is_dup = true; break; } } if (!is_dup) { - resolved_axis[num_resolved_axis++] = current; + out_axis[*out_num_axis] = current; + *out_num_axis += 1; } } - // iterates through input_data. - for (bool has_next = true; has_next; - has_next = NextIndex(input_num_dims, input_dims, temp_index)) { - size_t input_offset = - ReducedOutputOffset(input_num_dims, input_dims, temp_index, 0, nullptr); - size_t output_offset = - ReducedOutputOffset(input_num_dims, input_dims, temp_index, - num_resolved_axis, resolved_axis); - temp_sum[output_offset] += static_cast(input_data[input_offset]); - } - // takes average by num of elements added to get mean. - size_t num_elements_in_axis = 1; + return true; +} + +// This method expects that output_data has been initialized. +template +inline bool ReduceSumImpl(const In* input_data, const int* input_dims, + const int* output_dims, const int input_num_dims, + const int output_num_dims, const int* axis, + const int num_axis, int* input_iter, + Out* output_data) { + auto reducer = [](Out current, const In in, bool* overflow) -> Out { + const Out actual_in = static_cast(in); + return current + actual_in; + }; + return Reduce(input_data, input_dims, output_dims, input_num_dims, + output_num_dims, axis, num_axis, input_iter, reducer, + output_data); +} + +// Computes the mean of elements across dimensions given in axis. +// It does so in two stages, first calculates the sum of elements along the axis +// then divides it by the number of element in axis. 
+template +inline bool Mean(const T* input_data, const int* input_dims, + const int input_num_dims, T* output_data, + const int* output_dims, const int output_num_dims, + const int* axis, const int num_axis_dimensions, bool keep_dims, + int* temp_index, int* resolved_axis, U* temp_sum) { + // Reset output data. + size_t num_outputs = 1; + for (int idx = 0; idx < output_num_dims; ++idx) { + size_t current = static_cast(output_dims[idx]); + // Overflow prevention. + if (num_outputs > std::numeric_limits::max() / current) { + return false; + } + num_outputs *= current; + } + for (size_t idx = 0; idx < num_outputs; ++idx) { + output_data[idx] = T(); + temp_sum[idx] = U(); + } + + // Resolve axis. + int num_resolved_axis = 0; + if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis, + &num_resolved_axis)) { + return false; + } + + if (!ReduceSumImpl(input_data, input_dims, output_dims, input_num_dims, + output_num_dims, resolved_axis, num_resolved_axis, + temp_index, temp_sum)) { + return false; + } + + // Calculate mean by dividing output_data by num of aggregated element. + U num_elements_in_axis = 1; for (int idx = 0; idx < num_resolved_axis; ++idx) { size_t current = static_cast(input_dims[resolved_axis[idx]]); + // Overflow prevention. if (current > (std::numeric_limits::max() / num_elements_in_axis)) { return false; } num_elements_in_axis *= current; } + if (num_elements_in_axis > 0) { for (size_t idx = 0; idx < num_outputs; ++idx) { output_data[idx] = @@ -3470,7 +3763,6 @@ void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, T2* output_data, const Dims<4>& output_dims) { // The current ArgMax implemention can only determine the index of the maximum // value in the last dimension. So the axis argument is ignored. - TFLITE_DCHECK_EQ(axis[0], 3); // For ArgMax, the number of output dimensions = (number of input dimensions - // 1). 
For the sake of simplicity, the output dimensions are equal to the @@ -3529,8 +3821,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, int pad_height, float* output_data, const Dims<4>& output_dims) { const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3); - const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); const int input_height = ArraySize(input_dims, 2); const int input_width = ArraySize(input_dims, 1); const int filter_height = ArraySize(filter_dims, 2); @@ -3545,7 +3837,7 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, // computing their influence on the output, rather than looping through the // output elements in the typical "gather" access pattern of a conv. We // therefore must initialize the output array to zero. 
- for (int i = 0; i < RequiredBufferSizeForDims(output_dims); i++) { + for (int i = 0; i < FlatSize(output_dims); i++) { output_data[i] = 0.0f; } @@ -3570,8 +3862,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, float input_value = input_data[Offset(input_dims, in_channel, in_x, in_y, batch)]; float filter_value = - filter_data[Offset(filter_dims, out_channel, filter_x, - filter_y, in_channel)]; + filter_data[Offset(filter_dims, in_channel, filter_x, + filter_y, out_channel)]; output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] += input_value * filter_value; } @@ -3584,6 +3876,16 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, } } +template +inline bool EqualFn(T lhs, T rhs) { + return lhs == rhs; +} + +template +inline bool NotEqualFn(T lhs, T rhs) { + return lhs != rhs; +} + template inline bool GreaterFn(T lhs, T rhs) { return lhs > rhs; @@ -3608,20 +3910,14 @@ template F> inline void Comparison(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, const Dims<4>& input2_dims, bool* output_data, const Dims<4>& output_dims) { - const int64_t batches = - MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); - const int64_t height = - MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); - const int64_t width = - MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); - const int64_t depth = - MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); - for (int64_t i = 0; i < batches * height * width * depth; ++i) { + const int64_t flatsize = + MatchingFlatSize(input1_dims, input2_dims, output_dims); + for (int64_t i = 0; i < flatsize; ++i) { output_data[i] = F(input1_data[i], input2_data[i]); } } -template F> +template F> inline void Comparison(int left_shift, const T* input1_data, const Dims<4>& input1_dims, int32 input1_offset, int32 input1_multiplier, int input1_shift, @@ -3629,23 +3925,21 @@ inline void Comparison(int 
left_shift, const T* input1_data, int32 input2_offset, int32 input2_multiplier, int input2_shift, bool* output_data, const Dims<4>& output_dims) { - const int64_t batches = - MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); - const int64_t height = - MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); - const int64_t width = - MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); - const int64_t depth = - MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); - for (int64_t i = 0; i < batches * height * width * depth; ++i) { + const int64_t flatsize = + MatchingFlatSize(input1_dims, input2_dims, output_dims); + for (int64_t i = 0; i < flatsize; ++i) { const int32 input1_val = input1_offset + input1_data[i]; const int32 input2_val = input2_offset + input2_data[i]; const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); - const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); - const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); output_data[i] = F(scaled_input1_val, scaled_input2_val); } } @@ -3672,7 +3966,7 @@ inline void BroadcastComparison(const T* input1_data, } } -template F> +template F> inline void BroadcastComparison(int left_shift, const T* input1_data, const Dims<4>& input1_dims, int32 input1_offset, int32 input1_multiplier, int input1_shift, @@ -3694,11 +3988,13 @@ inline void BroadcastComparison(int left_shift, const T* input1_data, const int32 
shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); output_data[Offset(output_dims, c, x, y, b)] = F(scaled_input1_val, scaled_input2_val); } @@ -3724,11 +4020,11 @@ inline void BroadcastComparison(int left_shift, const T* input1_data, int32 input2_multiplier, int input2_shift, bool* output_data, \ const Dims<4>& output_dims) { \ gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \ - BroadcastComparison(left_shift, input1_data, input1_dims, \ - input1_offset, input1_multiplier, \ - input1_shift, input2_data, input2_dims, \ - input2_offset, input2_multiplier, \ - input2_shift, output_data, output_dims); \ + Comparison(left_shift, input1_data, input1_dims, \ + input1_offset, input1_multiplier, input1_shift, \ + input2_data, input2_dims, input2_offset, \ + input2_multiplier, input2_shift, output_data, \ + output_dims); \ } \ template \ inline void Broadcast##name( \ @@ -3753,6 +4049,8 @@ inline void BroadcastComparison(int left_shift, const T* input1_data, input2_offset, input2_multiplier, \ input2_shift, output_data, output_dims); \ } +TFLITE_COMPARISON_OP(Equal); +TFLITE_COMPARISON_OP(NotEqual); TFLITE_COMPARISON_OP(Greater); TFLITE_COMPARISON_OP(GreaterEqual); TFLITE_COMPARISON_OP(Less); @@ -3765,19 +4063,9 @@ inline void Select(const D* input_condition_data, const Dims<4>& input_x_dims, const T* input_y_data, const Dims<4>& input_y_dims, T* output_data, const Dims<4>& output_dims) { - 
const int64_t batches = - MatchingArraySize(input_condition_dims, 3, input_x_dims, 3, input_y_dims, - 3, output_dims, 3); - const int64_t height = - MatchingArraySize(input_condition_dims, 2, input_x_dims, 2, input_y_dims, - 2, output_dims, 2); - const int64_t width = MatchingArraySize(input_condition_dims, 1, input_x_dims, - 1, input_y_dims, 1, output_dims, 1); - const int64_t depth = MatchingArraySize(input_condition_dims, 0, input_x_dims, - 0, input_y_dims, 0, output_dims, 0); - - const int64_t num_elements = batches * height * width * depth; - for (int64_t i = 0; i < num_elements; ++i) { + const int64_t flatsize = + MatchingFlatSize(input_x_dims, input_y_dims, output_dims); + for (int64_t i = 0; i < flatsize; ++i) { output_data[i] = input_condition_data[i] ? input_x_data[i] : input_y_data[i]; } @@ -3789,25 +4077,52 @@ inline void RankOneSelect(const D* input_condition_data, const T* input_x_data, const Dims<4>& input_x_dims, const T* input_y_data, const Dims<4>& input_y_dims, T* output_data, const Dims<4>& output_dims) { - const int64_t rank = ArraySize(input_condition_dims, 0); - - const int64_t batches = - MatchingArraySize(input_x_dims, 3, input_y_dims, 3, output_dims, 3); - const int64_t height = - MatchingArraySize(input_x_dims, 2, input_y_dims, 2, output_dims, 2); - const int64_t width = - MatchingArraySize(input_x_dims, 1, input_y_dims, 1, output_dims, 1); - const int64_t depth = - MatchingArraySize(input_x_dims, 0, input_y_dims, 0, output_dims, 0); - - TFLITE_DCHECK_EQ(rank, batches); + const int64_t rank = MatchingArraySize(input_condition_dims, 0, input_x_dims, + 3, input_y_dims, 3, output_dims, 3); + const int64_t inner_size = + MatchingFlatSizeSkipDim(input_x_dims, 3, input_y_dims, output_dims); int64_t offset = 0; - int64_t size = depth * height * width; for (int64_t i = 0; i < rank; i++) { const T* input_data = input_condition_data[i] ? 
input_x_data : input_y_data; - memcpy(output_data + offset, input_data + offset, size * sizeof(T)); - offset += size; + memcpy(output_data + offset, input_data + offset, inner_size * sizeof(T)); + offset += inner_size; + } +} + +// For easy implementation, the indices is always a vector of size-4 vectors. +template +inline void SparseToDense(const std::vector>& indices, + const T* values, T default_value, T* output_data, + const Dims<4>& output_dims, bool value_is_scalar) { + const int value_count = indices.size(); + + // First fill the output_data with default value. + const int num_elements = FlatSize(output_dims); + for (int i = 0; i < num_elements; ++i) { + output_data[i] = default_value; + } + + // Special handle for value is scalar case to avoid checking the boolean + // condition within the loop every time. + if (value_is_scalar) { + for (int i = 0; i < value_count; ++i) { + const std::vector& index = indices[i]; + TFLITE_DCHECK_EQ(index.size(), 4); + const T value = *values; // just use the first value. + output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] = + value; + } + return; + } + + // Go through the values and indices to fill the sparse values. + for (int i = 0; i < value_count; ++i) { + const std::vector& index = indices[i]; + TFLITE_DCHECK_EQ(index.size(), 4); + const T value = values[i]; + output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] = + value; } } diff --git a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3d8765f11b2941ef5871c7db8e3582e506713aa6 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc @@ -0,0 +1,136 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/test_util.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace { +template +void TestOneResizeBilinear(int batch, int depth, int input_width, + int input_height, int output_width, + int output_height, float error_threshold) { + Dims<4> input_dims_inference = + MakeDimsForInference(depth, input_width, input_height, batch); + Dims<4> output_dims_inference = + MakeDimsForInference(depth, output_width, output_height, batch); + + const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference); + const int output_buffer_size = + RequiredBufferSizeForDims(output_dims_inference); + + std::vector input_data(input_buffer_size, 0); + std::vector reference_output_data(output_buffer_size, 0); + // Initialize the output data with something other than zero, so we can catch + // issue with kernels failing to initialize the output. 
+ std::vector output_data(output_buffer_size, 3); + + const T min_amplitude = static_cast(0); + const T max_amplitude = static_cast(255); + FillRandom(&input_data, min_amplitude, max_amplitude); + + Dims<4> output_size_dims = MakeDimsForInference(2, 1, 1, 1); + std::vector output_size_data = {output_height, output_width}; + + reference_ops::ResizeBilinear( + input_data.data(), input_dims_inference, output_size_data.data(), + output_size_dims, reference_output_data.data(), output_dims_inference); + optimized_ops::ResizeBilinear(input_data.data(), input_dims_inference, + output_size_data.data(), output_size_dims, + output_data.data(), output_dims_inference); + + double sum_diff = 0; + float max_abs_val = 0; + for (int i = 0; i < output_buffer_size; i++) { + sum_diff += std::abs(static_cast(output_data[i]) - + static_cast(reference_output_data[i])); + max_abs_val = std::max( + max_abs_val, std::abs(static_cast(reference_output_data[i]))); + } + + if (sum_diff != 0.f) { + const float mean_diff = static_cast(sum_diff / output_buffer_size); + const float relative_error = std::abs(mean_diff) / max_abs_val; + ASSERT_LT(relative_error, error_threshold); + } +} + +TEST(ResizeBilinear, TestResizeBilinear8Bit) { + const int kTestsToRun = 100 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + + TestOneResizeBilinear(batch, depth, input_width, input_height, + output_width, output_height, 0.025); + } +} + +TEST(ResizeBilinear2x2, TestResizeBilinear8Bit) { + const int kTestsToRun = 100 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + const int batch = 
ExponentialRandomPositiveInt(0.9f, 3, 20); + const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_width = input_width * 2; + const int output_height = input_height * 2; + + TestOneResizeBilinear(batch, depth, input_width, input_height, + output_width, output_height, 1e-5); + } +} + +TEST(ResizeBilinear, TestResizeBilinear) { + const int kTestsToRun = 100 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + + TestOneResizeBilinear(batch, depth, input_width, input_height, + output_width, output_height, 1e-5); + } +} + +TEST(ResizeBilinear2x2, TestResizeBilinear) { + const int kTestsToRun = 100 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_width = input_width * 2; + const int output_height = input_height * 2; + + TestOneResizeBilinear(batch, depth, input_width, input_height, + output_width, output_height, 1e-5); + } +} +} // namespace +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc new file mode 100644 index 
0000000000000000000000000000000000000000..d781a7b642036f3c5ddaa366f257fe26511c83c3 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc @@ -0,0 +1,227 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/test_util.h" + +namespace tflite { +namespace { + +void RunSoftmaxFloatReference(const uint8* input_data, + const Dims<4>& dims_common, int32 input_offset, + const double input_scale, int stride, float beta, + uint8* reference_output_data) { + const int ref_buffer_size = RequiredBufferSizeForDims(dims_common); + std::vector reference_dequant_data(ref_buffer_size); + std::vector reference_output_float_data(ref_buffer_size); + + // Reference data generated via Dequant of input into float, and then applying + // float Softmax. 
+ reference_ops::Dequantize(input_data, dims_common, input_offset, input_scale, + reference_dequant_data.data(), dims_common); + optimized_ops::Softmax(reference_dequant_data.data(), dims_common, beta, + reference_output_float_data.data(), dims_common); + // Work with quantized scaling for Softmax, under which 256 represents 1, but + // we limit this to 255. + for (int i = 0; i < ref_buffer_size; i++) { + reference_output_data[i] = std::min( + 255, + static_cast(std::round(256.0f * reference_output_float_data[i]))); + } +} + +void CheckOutputData(const uint8* test_output, const uint8* reference_output, + const Dims<4>& dims_common, const string& check_label, + bool be_exacting) { + const int buffer_size = RequiredBufferSizeForDims(dims_common); + // While calculating some metrics in floating point, we work with quantized + // scaling. + std::vector diff(buffer_size); + int64_t sum_diff = 0; + int64_t sum_abs_diff = 0; + for (int i = 0; i < buffer_size; i++) { + diff[i] = static_cast(test_output[i]) - reference_output[i]; + sum_diff += diff[i]; + sum_abs_diff += std::abs(diff[i]); + } + // These stats help understand test failures. + std::sort(std::begin(diff), std::end(diff)); + const int min_diff = diff.front(); + const int max_diff = diff.back(); + const int median_diff = diff[diff.size() / 2]; + const float mean_diff = static_cast(sum_diff) / buffer_size; + const float mean_abs_diff = static_cast(sum_abs_diff) / buffer_size; + // We either check for bit exactness (against the reference quantized version) + // or for general accuracy, allowing off-by-one (against the float reference). + if (be_exacting) { + ASSERT_TRUE(std::abs(min_diff) == 0 && std::abs(max_diff) == 0); + } else { + // For small numbers of samples, the estimates of the means vary more. + // Rather than widen the tolerances, we skip the smaller tests. 
+ ASSERT_TRUE(((std::abs(mean_diff) < 2e-2f && mean_abs_diff < 3e-2f) || + buffer_size < 10000) && + std::abs(median_diff) == 0 && std::abs(min_diff) <= 1 && + std::abs(max_diff) <= 1); + } +} + +// Runs the Softmax and compares against the float reference implementation and +// the quantized reference implementation. +void RunOneSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common, + int32 input_offset, const double input_scale, int stride, + float beta) { + const int buffer_size = RequiredBufferSizeForDims(dims_common); + std::vector optimized_softmax_output(buffer_size); + std::vector reference_float_softmax_output(buffer_size); + std::vector reference_quant_softmax_output(buffer_size); + + RunSoftmaxFloatReference(input_data, dims_common, input_offset, input_scale, + stride, beta, reference_float_softmax_output.data()); + + int32 input_beta_multiplier; + int input_beta_left_shift; + static const int kScaledDiffIntegerBits = 5; + tflite::PreprocessSoftmaxScaling(beta, input_scale, kScaledDiffIntegerBits, + &input_beta_multiplier, + &input_beta_left_shift); + // diff_min has a negative value, and is used to limit the maximum magnitude + // of the diffs, which are <= 0. 
+ const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits, + input_beta_left_shift); + + optimized_ops::Softmax(input_data, dims_common, input_beta_multiplier, + input_beta_left_shift, diff_min, + optimized_softmax_output.data(), dims_common); + reference_ops::Softmax(input_data, dims_common, input_beta_multiplier, + input_beta_left_shift, diff_min, + reference_quant_softmax_output.data(), dims_common); + + CheckOutputData(optimized_softmax_output.data(), + reference_float_softmax_output.data(), dims_common, + "Optimized vs float reference", false); + CheckOutputData(optimized_softmax_output.data(), + reference_quant_softmax_output.data(), dims_common, + "Optimized vs quant reference", true); + CheckOutputData(reference_quant_softmax_output.data(), + reference_float_softmax_output.data(), dims_common, + "Quant reference vs float reference", false); +} + +// This function picks some random Softmax params, which are checked for +// desirability. If not acceptable, it returns false. If they're OK, +// it runs the Softmax test and returns true. This allows the caller +// to loop until a test has been run. +// +// Currently we do not reject for any reason. +bool TryOneUniformSoftmax() { + // We pick mostly positive values, on the whole emphasizing smaller values and + // therefore faster tests. We test a wider range of depths. In the case of + // Softmax, the width and height really just create test repetitions. 
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = ExponentialRandomPositiveInt(0.75f, 175, 500); + const int input_width = ExponentialRandomPositiveInt(0.8f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.8f, 20, 200); + const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); + const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0)); + const int32 input_offset = UniformRandomInt(-256, 0); + const float beta = 1.0f + ExponentialRandomPositiveFloat(0.9f, 2, 10); + + Dims<4> dims_common = + MakeDimsForInference(input_depth, input_width, input_height, batch); + const int buffer_size = RequiredBufferSizeForDims(dims_common); + + std::vector input_data(buffer_size); + FillRandom(&input_data); + RunOneSoftmaxTest(input_data.data(), dims_common, input_offset, input_scale, + stride, beta); + return true; +} + +// See TryOneUniformSoftmax() for a general description. +// +// Tests with "skyscraper" input patterns are included for two reasons. (a) +// Bimodal distributions are potentially challenging and perhaps more +// realistic than simple uniform random inputs. (b) Some implementations of +// Softmax may adapt as they traverse the depth, and so we test handling of +// cases where relatively small values are encountered at the beginning and end. +bool TryOneSkyscraperSoftmax(bool small_depth) { + // We pick mostly positive values, on the whole emphasizing smaller values and + // therefore faster tests. We test a wider range of depths. In the case of + // Softmax, the width and height really just create test repetitions. + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = small_depth + ? 
ExponentialRandomPositiveInt(0.75f, 40, 500) + : ExponentialRandomPositiveInt(0.75f, 175, 500); + const int input_width = ExponentialRandomPositiveInt(0.7f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.7f, 20, 200); + const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); + const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0)); + const int32 input_offset = UniformRandomInt(-256, 0); + const float beta = 1.0f + ExponentialRandomPositiveFloat(0.9f, 2, 10); + // Extra parameters for skyscraper input patterns. + const double middle_proportion = + ExponentialRandomPositiveFloat(0.65f, 0.1, 1.0); + const int middle_min = UniformRandomInt(0, 255); + const int sides_max = UniformRandomInt(0, middle_min); + + Dims<4> dims_common = + MakeDimsForInference(input_depth, input_width, input_height, batch); + const int buffer_size = RequiredBufferSizeForDims(dims_common); + + std::vector input_data(buffer_size); + FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min, + sides_max); + RunOneSoftmaxTest(input_data.data(), dims_common, input_offset, input_scale, + stride, beta); + return true; +} + +TEST(TestQuantizedSoftmax, UniformSoftmaxTests) { + const int kTestsToRun = 1000; + for (int i = 0; i < kTestsToRun; i++) { + while (!TryOneUniformSoftmax()) { + } + } +} + +TEST(TestQuantizedSoftmax, SkyscraperSoftmaxTests) { + const int kTestsToRun = 1000; + for (int i = 0; i < kTestsToRun; i++) { + while (!TryOneSkyscraperSoftmax(false)) { + } + } +} + +TEST(TestQuantizedSoftmax, SmallSkyscraperSoftmaxTests) { + const int kTestsToRun = 1000; + for (int i = 0; i < kTestsToRun; i++) { + while (!TryOneSkyscraperSoftmax(true)) { + } + } +} +} // namespace +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h index 62cea143e6afc0631493012be26808a89eb03138..ce887cea8b794b4b0cfd31722581cf9327be625e 100644 --- 
a/tensorflow/contrib/lite/kernels/internal/tensor.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor.h
@@ -49,6 +49,34 @@ inline bool* GetTensorData(TfLiteTensor* tensor) {
   return tensor != nullptr ? tensor->data.b : nullptr;
 }
 
+template <typename T>
+inline const T* GetTensorData(const TfLiteTensor* tensor);
+
+template <>
+inline const float* GetTensorData(const TfLiteTensor* tensor) {
+  return tensor != nullptr ? tensor->data.f : nullptr;
+}
+
+template <>
+inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) {
+  return tensor != nullptr ? tensor->data.uint8 : nullptr;
+}
+
+template <>
+inline const int32_t* GetTensorData(const TfLiteTensor* tensor) {
+  return tensor != nullptr ? tensor->data.i32 : nullptr;
+}
+
+template <>
+inline const int64_t* GetTensorData(const TfLiteTensor* tensor) {
+  return tensor != nullptr ? tensor->data.i64 : nullptr;
+}
+
+template <>
+inline const bool* GetTensorData(const TfLiteTensor* tensor) {
+  return tensor != nullptr ? tensor->data.b : nullptr;
+}
+
 inline int RemapDim(int max_dimensions, int d) {
   return max_dimensions - d - 1;
 }
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
index e1c9ccd84b09fdca0241192ef4dff62ded096433..5160e22307ae0894fabd0e9c4f7b9cd38b00840e 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
@@ -23,6 +23,9 @@ namespace tensor_utils {
 // Limit a float input f between +abs_limit and -abs_limit.
 float Clip(float f, float abs_limit);
 
+// Checks if all entries of vector are zero.
+bool IsZeroVector(const float* vector, int v_size);
+
 // Quantizes a buffer of floating point values using a symmetric quantization
 // (i.e. linear quantization without an offset) to 8-bit signed integers.
 // It also outputs the range (min, max) of the floating point buffer, and the
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
index 3d8a2eada0c30132b5646327da5cd9fd80ccd39f..14ee528394b6872d9e79969db0e431658277f56b 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
@@ -32,6 +32,25 @@ TEST(uKernels, ClipTest) {
       {0.0, -0.5, 1.0, -1.5, 2.0, -2.0, 2.0, -2.0, 2.0, -2.0})));
 }
 
+TEST(uKernels, IsZeroTest) {
+  constexpr int kVectorSize = 21;
+  static float zeros[kVectorSize] = {0.0};
+  EXPECT_TRUE(IsZeroVector(zeros, kVectorSize));
+
+  static float nonzeros[kVectorSize] = {
+      1e-6,  1e-7,  1e-8,  1e-9,  1e-10, 1e-11, 1e-12,
+      1e-13, 1e-14, 1e-15, 1e-16, 1e-17, 1e-18, 1e-19,
+      1e-20, 1e-21, 1e-22, 1e-23, 1e-24, 1e-25, 1e-26};
+  EXPECT_FALSE(IsZeroVector(nonzeros, kVectorSize));
+}
+
+TEST(uKernels, GeneratedIsZeroTest) {
+  constexpr int kVectorSize = 39;
+  std::vector<float> input(kVectorSize);
+  ZeroVector(input.data(), kVectorSize);
+  EXPECT_TRUE(IsZeroVector(input.data(), kVectorSize));
+}
+
 TEST(uKernels, SymmetricQuantizeFloatsTest) {
   constexpr int kVectorSize = 9;
   static float input[kVectorSize] = {-640, -635.0, -630, 10.0, 2.0,
diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.cc b/tensorflow/contrib/lite/kernels/internal/test_util.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9b1fd9b344d99103ee2a1b5b95fd697ccf4a64d0
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/test_util.cc
@@ -0,0 +1,121 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/test_util.h" + +#include +#include + +namespace tflite { + +Dims<4> MakeDimsForInference(int depth, int width, int height, int batch) { + Dims<4> result; + int cum_prod = 1; + + result.sizes[0] = depth; + result.strides[0] = cum_prod; + cum_prod *= result.sizes[0]; + + result.sizes[1] = width; + result.strides[1] = cum_prod; + cum_prod *= result.sizes[1]; + + result.sizes[2] = height; + result.strides[2] = cum_prod; + cum_prod *= result.sizes[2]; + + result.sizes[3] = batch; + result.strides[3] = cum_prod; + + return result; +} + +// this is a copied from an internal function in propagate_fixed_sizes.cc +bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width, + int filter_height, int stride, PaddingType padding_type, + Dims<4>* output_dims, int* pad_width, int* pad_height) { + const int input_width = ArraySize(input_dims, 1); + const int input_height = ArraySize(input_dims, 2); + const int batch = ArraySize(input_dims, 3); + + int output_height = 0; + int output_width = 0; + if (padding_type == PaddingType::kValid) { + output_height = (input_height + stride - filter_height) / stride; + output_width = (input_width + stride - filter_width) / stride; + } else if (padding_type == PaddingType::kSame) { + output_height = (input_height + stride - 1) / stride; + output_width = (input_width + stride - 1) / stride; + } else { + return false; + } + + if (output_width <= 0 || output_height <= 0) { + return false; + } + + 
*pad_height = + ((output_height - 1) * stride + filter_height - input_height) / 2; + *pad_width = ((output_width - 1) * stride + filter_width - input_width) / 2; + *output_dims = + MakeDimsForInference(output_depth, output_width, output_height, batch); + return true; +} + +std::mt19937& RandomEngine() { + static std::mt19937 engine; + return engine; +} + +int UniformRandomInt(int min, int max) { + std::uniform_int_distribution dist(min, max); + return dist(RandomEngine()); +} + +float UniformRandomFloat(float min, float max) { + std::uniform_real_distribution dist(min, max); + return dist(RandomEngine()); +} + +int ExponentialRandomPositiveInt(float percentile, int percentile_val, + int max_val) { + const float lambda = + -std::log(1.f - percentile) / static_cast(percentile_val); + std::exponential_distribution dist(lambda); + float val; + do { + val = dist(RandomEngine()); + } while (!val || !std::isfinite(val) || val > max_val); + return static_cast(std::ceil(val)); +} + +float ExponentialRandomPositiveFloat(float percentile, float percentile_val, + float max_val) { + const float lambda = + -std::log(1.f - percentile) / static_cast(percentile_val); + std::exponential_distribution dist(lambda); + float val; + do { + val = dist(RandomEngine()); + } while (!std::isfinite(val) || val > max_val); + return val; +} + +void FillRandom(std::vector* vec, float min, float max) { + std::uniform_real_distribution dist(min, max); + auto gen = std::bind(dist, RandomEngine()); + std::generate(std::begin(*vec), std::end(*vec), gen); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.h b/tensorflow/contrib/lite/kernels/internal/test_util.h new file mode 100644 index 0000000000000000000000000000000000000000..26078cef49a7868c64ff0095898eebe5e4de8751 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/test_util.h @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TEST_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TEST_UTIL_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { + +// Creates a Dims struct from a set of dimensions. +Dims<4> MakeDimsForInference(int depth, int width, int height, int batch); + +// Computes output and padding dimensions. +bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width, + int filter_height, int stride, PaddingType padding_type, + Dims<4>* output_dims, int* pad_width, int* pad_height); + +// Returns a mt19937 random engine. +std::mt19937& RandomEngine(); + +// Returns a random integer uniformly distributed between |min| and |max|. +int UniformRandomInt(int min, int max); + +// Returns a random float uniformly distributed between |min| and |max|. +float UniformRandomFloat(float min, float max); + +// Returns a random element in |v|. +template +const T& RandomElement(const std::vector& v) { + return v[UniformRandomInt(0, v.size() - 1)]; +} + +// Returns a random exponentially distributed integer. +int ExponentialRandomPositiveInt(float percentile, int percentile_val, + int max_val); + +// Returns a random exponentially distributed float. 
+float ExponentialRandomPositiveFloat(float percentile, float percentile_val, + float max_val); + +// Fills a vector with random floats between |min| and |max|. +void FillRandom(std::vector* vec, float min, float max); + +// Fills a vector with random numbers between |min| and |max|. +template +void FillRandom(std::vector* vec, T min, T max) { + std::uniform_int_distribution dist(min, max); + auto gen = std::bind(dist, RandomEngine()); + std::generate(std::begin(*vec), std::end(*vec), gen); +} + +// Fills a vector with random numbers. +template +void FillRandom(std::vector* vec) { + FillRandom(vec, std::numeric_limits::min(), std::numeric_limits::max()); +} + +template +void FillRandom(typename std::vector::iterator begin_it, + typename std::vector::iterator end_it, T min, T max) { + std::uniform_int_distribution dist(min, max); + auto gen = std::bind(dist, RandomEngine()); + std::generate(begin_it, end_it, gen); +} + +// Fill with a "skyscraper" pattern, in which there is a central section (across +// the depth) with higher values than the surround. 
+template +void FillRandomSkyscraper(std::vector* vec, int depth, + double middle_proportion, uint8 middle_min, + uint8 sides_max) { + for (auto base_it = std::begin(*vec); base_it != std::end(*vec); + base_it += depth) { + auto left_it = base_it + std::ceil(0.5 * depth * (1.0 - middle_proportion)); + auto right_it = + base_it + std::ceil(0.5 * depth * (1.0 + middle_proportion)); + FillRandom(base_it, left_it, std::numeric_limits::min(), sides_max); + FillRandom(left_it, right_it, middle_min, std::numeric_limits::max()); + FillRandom(right_it, base_it + depth, std::numeric_limits::min(), + sides_max); + } +} + +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TEST_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index 3290c364c18224edb733c177ad72bf86b6892434..3ecef15271c6eb9dd9d6dd370377fddda2723fcf 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,11 +15,15 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ +#include +#include + #include "tensorflow/contrib/lite/kernels/internal/compatibility.h" namespace tflite { enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu }; +enum class PaddingType { kNone, kSame, kValid }; // Quantization parameters, determining the mapping of quantized values // to real values (i.e. 
determining how quantized values are mathematically @@ -43,6 +47,121 @@ struct Dims { int strides[N]; }; +class RuntimeShape { + public: + // Shapes with dimensions up to 4 are stored directly in the structure, while + // larger shapes are separately allocated. + static constexpr int kMaxSmallSize = 4; + + RuntimeShape() : size_(0) {} + + explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) { + if (dimensions_count > kMaxSmallSize) { + dims_pointer_ = new int32[dimensions_count]; + } + } + + RuntimeShape(int dimensions_count, const int32* dims_data) : size_(0) { + ReplaceWith(dimensions_count, dims_data); + } + + ~RuntimeShape() { + if (size_ > kMaxSmallSize) { + delete[] dims_pointer_; + } + } + + inline int32 DimensionsCount() const { return size_; } + inline int32 Dims(int i) const { + TFLITE_DCHECK_GE(i, 0); + TFLITE_DCHECK_LT(i, size_); + return size_ > kMaxSmallSize ? dims_pointer_[i] : dims_[i]; + } + inline void SetDim(int i, int32 val) { + TFLITE_DCHECK_GE(i, 0); + TFLITE_DCHECK_LT(i, size_); + if (size_ > kMaxSmallSize) { + dims_pointer_[i] = val; + } else { + dims_[i] = val; + } + } + inline int32* DimsData() { + return size_ > kMaxSmallSize ? dims_pointer_ : dims_; + } + inline const int32* DimsData() const { + return size_ > kMaxSmallSize ? 
dims_pointer_ : dims_; + } + + inline void Resize(int dimensions_count) { + if (size_ > kMaxSmallSize) { + delete[] dims_pointer_; + } + size_ = dimensions_count; + if (dimensions_count > kMaxSmallSize) { + dims_pointer_ = new int32[dimensions_count]; + } + } + + inline void ReplaceWith(int dimensions_count, const int32* dims_data) { + Resize(dimensions_count); + int32* dst_dims = DimsData(); + std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32)); + } + + template + inline void BuildFrom(const T& src_iterable) { + const int dimensions_count = + std::distance(src_iterable.begin(), src_iterable.end()); + Resize(dimensions_count); + int32* data = DimsData(); + for (auto it : src_iterable) { + *data = it; + ++data; + } + } + + inline void BuildFrom(const std::initializer_list init_list) { + BuildFrom>(init_list); + } + + // Returns the total count of elements, that is the size when flattened into a + // vector. + inline int FlatSize() const { + int buffer_size = 1; + const int* dims_data = DimsData(); + for (int i = 0; i < size_; i++) { + const int dim = dims_data[i]; + TFLITE_DCHECK_GE(dim, 1); + buffer_size *= dim; + } + return buffer_size; + } + + private: + int32 size_; + union { + int32 dims_[kMaxSmallSize]; + int32* dims_pointer_; + }; +}; + +// Converts inference-style shape to legacy tflite::Dims<4>. +inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) { + tflite::Dims<4> result; + const int dimensions_count = array_shape.DimensionsCount(); + TFLITE_CHECK_LE(dimensions_count, 4); + int cum_prod = 1; + for (int i = 0; i < 4; i++) { + const int new_dim = + (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1; + result.sizes[i] = new_dim; + result.strides[i] = cum_prod; + cum_prod *= new_dim; + } + return result; +} + // Gets next index to iterate through a multidimensional array. 
inline bool NextIndex(const int num_dims, const int* dims, int* current) { TFLITE_DCHECK_GT(num_dims, 0); @@ -132,11 +251,11 @@ int MatchingArraySize(const ArrayType1& array1, int index1, template inline int FlatSize(const Dims& dims) { - int max_offset = 0; - for (int i = 0; i < N; i++) { - max_offset += (dims.sizes[i] - 1) * dims.strides[i]; + int flat_size = 1; + for (int i = 0; i < N; ++i) { + flat_size *= dims.sizes[i]; } - return max_offset + 1; + return flat_size; } // Deprecated. Prefer FlatSize. @@ -148,7 +267,7 @@ inline int RequiredBufferSizeForDims(const Dims<4>& dims) { // arrays. template inline int MatchingFlatSize(const Dims& dims, const Dims& check_dims_0) { - for (int i = 0; i < N; i++) { + for (int i = 0; i < N; ++i) { TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i)); } return FlatSize(dims); @@ -157,7 +276,7 @@ inline int MatchingFlatSize(const Dims& dims, const Dims& check_dims_0) { template inline int MatchingFlatSize(const Dims& dims, const Dims& check_dims_0, const Dims& check_dims_1) { - for (int i = 0; i < N; i++) { + for (int i = 0; i < N; ++i) { TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i)); } return MatchingFlatSize(dims, check_dims_1); @@ -167,7 +286,7 @@ template inline int MatchingFlatSize(const Dims& dims, const Dims& check_dims_0, const Dims& check_dims_1, const Dims& check_dims_2) { - for (int i = 0; i < N; i++) { + for (int i = 0; i < N; ++i) { TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i)); } return FlatSize(dims, check_dims_1, check_dims_2); @@ -178,7 +297,7 @@ inline int MatchingFlatSize(const Dims& dims, const Dims& check_dims_0, const Dims& check_dims_1, const Dims& check_dims_2, const Dims& check_dims_3) { - for (int i = 0; i < N; i++) { + for (int i = 0; i < N; ++i) { TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i)); } return FlatSize(dims, check_dims_1, check_dims_2, check_dims_3); @@ -191,7 +310,7 @@ template inline int FlatSizeSkipDim(const Dims& dims, 
int skip_dim) { TFLITE_DCHECK(skip_dim >= 0 && skip_dim < N); int flat_size = 1; - for (int i = 0; i < N; i++) { + for (int i = 0; i < N; ++i) { flat_size *= (i == skip_dim) ? 1 : dims.sizes[i]; } return flat_size; @@ -201,7 +320,7 @@ inline int FlatSizeSkipDim(const Dims& dims, int skip_dim) { template inline int MatchingFlatSizeSkipDim(const Dims& dims, int skip_dim, const Dims& check_dims_0) { - for (int i = 0; i < N; i++) { + for (int i = 0; i < N; ++i) { if (i != skip_dim) { TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i)); } @@ -213,7 +332,7 @@ template inline int MatchingFlatSizeSkipDim(const Dims& dims, int skip_dim, const Dims& check_dims_0, const Dims& check_dims_1) { - for (int i = 0; i < N; i++) { + for (int i = 0; i < N; ++i) { if (i != skip_dim) { TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i)); } @@ -226,7 +345,7 @@ inline int MatchingFlatSizeSkipDim(const Dims& dims, int skip_dim, const Dims& check_dims_0, const Dims& check_dims_1, const Dims& check_dims_2) { - for (int i = 0; i < N; i++) { + for (int i = 0; i < N; ++i) { if (i != skip_dim) { TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i)); } @@ -240,7 +359,7 @@ inline int MatchingFlatSizeSkipDim(const Dims& dims, int skip_dim, const Dims& check_dims_1, const Dims& check_dims_2, const Dims& check_dims_3) { - for (int i = 0; i < N; i++) { + for (int i = 0; i < N; ++i) { if (i != skip_dim) { TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i)); } @@ -259,6 +378,14 @@ bool IsPackedWithoutStrides(const Dims& dims) { return true; } +template +void ComputeStrides(Dims* dims) { + dims->strides[0] = 1; + for (int d = 1; d < N; d++) { + dims->strides[d] = dims->strides[d - 1] * dims->sizes[d - 1]; + } +} + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc index 
955e8c5764c6adad37a0009f4ddf8accb437b174..184028427fb193aa99cf155961c16eda1298e326 100644
--- a/tensorflow/contrib/lite/kernels/kernel_util.cc
+++ b/tensorflow/contrib/lite/kernels/kernel_util.cc
@@ -22,9 +22,12 @@ limitations under the License.
 
 namespace tflite {
 
-TfLiteStatus GetQuantizedConvolutionMultipler(
-    TfLiteContext* context, TfLiteTensor* input, TfLiteTensor* filter,
-    TfLiteTensor* bias, TfLiteTensor* output, double* multiplier) {
+TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
+                                              const TfLiteTensor* input,
+                                              const TfLiteTensor* filter,
+                                              const TfLiteTensor* bias,
+                                              TfLiteTensor* output,
+                                              double* multiplier) {
   const double input_product_scale = input->params.scale * filter->params.scale;
   const double bias_scale = bias->params.scale;
   const double output_scale = output->params.scale;
@@ -34,7 +37,6 @@ TfLiteStatus GetQuantizedConvolutionMultipler(
   TF_LITE_ENSURE(context, std::abs(input_product_scale - bias_scale) <=
                               1e-6 * std::min(input_product_scale, bias_scale));
   TF_LITE_ENSURE(context, input_product_scale >= 0);
-  TF_LITE_ENSURE(context, input_product_scale < output_scale);
 
   *multiplier = input_product_scale / output_scale;
 
@@ -87,13 +89,13 @@ void CalculateActivationRangeFloat(TfLiteFusedActivation activation,
   }
 }
 
-bool HaveSameShapes(TfLiteTensor* input1, TfLiteTensor* input2) {
+bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2) {
   return TfLiteIntArrayEqual(input1->dims, input2->dims);
 }
 
 TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
-                                        TfLiteTensor* input1,
-                                        TfLiteTensor* input2,
+                                        const TfLiteTensor* input1,
+                                        const TfLiteTensor* input2,
                                         TfLiteIntArray** output_shape) {
   int64_t dims1 = NumDimensions(input1);
   int64_t dims2 = NumDimensions(input2);
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h
index e225443a67b2ac6fb67bfd7e0828417da4ed4dab..82cded36f2ed2777daccafee5890f47c0d7254e8 100644
--- 
a/tensorflow/contrib/lite/kernels/kernel_util.h +++ b/tensorflow/contrib/lite/kernels/kernel_util.h @@ -24,8 +24,8 @@ inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; } inline int SizeOfDimension(const TfLiteTensor* t, int dim) { return t->dims->data[dim]; } -inline TfLiteTensor* GetInput(TfLiteContext* context, TfLiteNode* node, - int index) { +inline const TfLiteTensor* GetInput(TfLiteContext* context, TfLiteNode* node, + int index) { return &context->tensors[node->inputs->data[index]]; } inline TfLiteTensor* GetOutput(TfLiteContext* context, TfLiteNode* node, @@ -47,8 +47,9 @@ inline int64_t NumElements(const TfLiteTensor* t) { return count; } -inline TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context, - const TfLiteNode* node, int index) { +inline const TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context, + const TfLiteNode* node, + int index) { const bool use_tensor = node->inputs->data[index] != kOptionalTensor; if (use_tensor) { return &context->tensors[node->inputs->data[index]]; @@ -78,9 +79,12 @@ inline void SetTensorToDynamic(TfLiteTensor* tensor) { // Calculates the multiplication factor for a quantized convolution (or // quantized depthwise convolution) involving the given tensors. Returns an // error if the scales of the tensors are not compatible. -TfLiteStatus GetQuantizedConvolutionMultipler( - TfLiteContext* context, TfLiteTensor* input, TfLiteTensor* filter, - TfLiteTensor* bias, TfLiteTensor* output, double* multiplier); +TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context, + const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, + TfLiteTensor* output, + double* multiplier); // Calculates the useful range of an activation layer given its activation // tensor. @@ -92,13 +96,13 @@ void CalculateActivationRangeFloat(TfLiteFusedActivation activation, float* activation_max); // Return true if the given tensors have the same shape. 
-bool HaveSameShapes(TfLiteTensor* input1, TfLiteTensor* input2); +bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2); // Calculate the output_shape that is necessary for element-wise operations // with broadcasting involving the two input tensors. TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context, - TfLiteTensor* input1, - TfLiteTensor* input2, + const TfLiteTensor* input1, + const TfLiteTensor* input2, TfLiteIntArray** output_shape); } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/kernel_util_test.cc b/tensorflow/contrib/lite/kernels/kernel_util_test.cc index c65b68970f6853e17af3a70aad7a2bc982a1ee60..bf6f249acc85ee050681eb6f33067be2a1aa037e 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util_test.cc +++ b/tensorflow/contrib/lite/kernels/kernel_util_test.cc @@ -33,7 +33,7 @@ class KernelUtilTest : public ::testing::Test { tensor1_.allocation_type = kTfLiteMmapRo; tensor2_.allocation_type = kTfLiteMmapRo; } - ~KernelUtilTest() { + ~KernelUtilTest() override { TfLiteTensorFree(&tensor1_); TfLiteTensorFree(&tensor2_); } diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc index e67f4e06f3680f8c9447a9e831b63415994ea176..3205c1cc52724207904621a5870636841ef379fe 100644 --- a/tensorflow/contrib/lite/kernels/l2norm.cc +++ b/tensorflow/contrib/lite/kernels/l2norm.cc @@ -40,7 +40,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TF_LITE_ENSURE(context, NumDimensions(input) <= 4); @@ -64,7 +64,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* input = 
GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); if (output->type == kTfLiteFloat32) { @@ -94,7 +94,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } #undef TF_LITE_L2NORM } else { - context->ReportError(context, "Inputs and outputs not all float types."); + context->ReportError(context, "Output type is %d, requires float.", + output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/contrib/lite/kernels/l2norm_test.cc index 042314ccf55cb6de12c743448fbe040f35e7baab..070ed60040997f18f7e8053acc9532adc2377400 100644 --- a/tensorflow/contrib/lite/kernels/l2norm_test.cc +++ b/tensorflow/contrib/lite/kernels/l2norm_test.cc @@ -67,7 +67,7 @@ class L2NormOpModel : public SingleOpModel { int output_; }; -TEST(L2NormOpTest, SimpleTest) { +TEST(L2NormOpTest, SimpleFloatTest) { L2NormOpModel m({1, 1, 1, 6}, TensorType_FLOAT32, ActivationFunctionType_NONE); m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); @@ -76,6 +76,23 @@ TEST(L2NormOpTest, SimpleTest) { ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05})); } +TEST(L2NormOpTest, MultipleBatchFloatTest) { + L2NormOpModel m({3, 1, 1, 6}, TensorType_FLOAT32, + ActivationFunctionType_NONE); + m.SetInput({ + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 1 + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 2 + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 3 + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({ + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1 + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 2 + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 3 + })); +} + TEST(L2NormOpTest, SimpleUint8Test) { L2NormOpModel m({1, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE); @@ -88,6 +105,32 @@ TEST(L2NormOpTest, SimpleUint8Test) { ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1))); } +TEST(L2NormOpTest, MultipleBatchUint8Test) 
{ + L2NormOpModel m({3, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE); + + m.QuantizeAndPopulate(m.input(), + { + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 1 + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 2 + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 3 + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({ + 58, 166, 173, 205, 83, 134, // batch 1 + 58, 166, 173, 205, 83, 134, // batch 2 + 58, 166, 173, 205, 83, 134, // batch 3 + })); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1 + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 2 + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 3 + }, + 0.1))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/local_response_norm.cc b/tensorflow/contrib/lite/kernels/local_response_norm.cc index c1c70d0dfa0050dee3815aa15f5d16d2e7ddc721..36dca299d0e07a84af60a13dfeb50b0f8fe38ee2 100644 --- a/tensorflow/contrib/lite/kernels/local_response_norm.cc +++ b/tensorflow/contrib/lite/kernels/local_response_norm.cc @@ -38,7 +38,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); @@ -60,7 +60,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); if (output->type == kTfLiteFloat32) { @@ -77,7 +77,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } #undef TF_LITE_LOCAL_RESPONSE_NORM } 
else { - context->ReportError(context, "Inputs and outputs not all float types."); + context->ReportError(context, "Output type is %d, requires float.", + output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/contrib/lite/kernels/lsh_projection.cc index 0ee35775d50b8750455572f789d7b92481655a95..25d2dc2cdd699b4d9c8e83eb848fce0df3c59c15 100644 --- a/tensorflow/contrib/lite/kernels/lsh_projection.cc +++ b/tensorflow/contrib/lite/kernels/lsh_projection.cc @@ -77,16 +77,16 @@ TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* hash = GetInput(context, node, 0); + const TfLiteTensor* hash = GetInput(context, node, 0); TF_LITE_ENSURE_EQ(context, NumDimensions(hash), 2); // Support up to 32 bits. TF_LITE_ENSURE(context, SizeOfDimension(hash, 1) <= 32); - TfLiteTensor* input = GetInput(context, node, 1); + const TfLiteTensor* input = GetInput(context, node, 1); TF_LITE_ENSURE(context, NumDimensions(input) >= 1); if (NumInputs(node) == 3) { - TfLiteTensor* weight = GetInput(context, node, 2); + const TfLiteTensor* weight = GetInput(context, node, 2); TF_LITE_ENSURE_EQ(context, NumDimensions(weight), 1); TF_LITE_ENSURE_EQ(context, SizeOfDimension(weight, 0), SizeOfDimension(input, 0)); @@ -173,9 +173,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { reinterpret_cast(node->builtin_data); int32_t* out_buf = GetOutput(context, node, 0)->data.i32; - TfLiteTensor* hash = GetInput(context, node, 0); - TfLiteTensor* input = GetInput(context, node, 1); - TfLiteTensor* weight = + const TfLiteTensor* hash = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 1); + const TfLiteTensor* weight = NumInputs(node) == 2 ? 
nullptr : GetInput(context, node, 2); switch (params->type) { diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index a1521efbb4e2dfc378915fb04b0cd156353bcc5e..eb26a02455ce2afccaa081a72d93a9ceeca746cc 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -25,6 +25,8 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" @@ -34,6 +36,17 @@ namespace ops { namespace builtin { namespace lstm { +struct OpData { + // Which kernel type to use. Full kernel (18-inputs) or basic kernel + // (5-inputs). + TfLiteLSTMKernelType kernel_type; + // Only used by full kernel. + int scratch_tensor_index; +}; + +// For full inputs kernel (18-inputs). +namespace full { + // Input Tensors of size {n_batch, n_input} constexpr int kInputTensor = 0; @@ -71,20 +84,18 @@ constexpr int kCellStateTensor = 1; constexpr int kOutputTensor = 2; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* scratch_tensor_index = new int; - context->AddTensors(context, 1, scratch_tensor_index); - return scratch_tensor_index; -} - -void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast(buffer); + auto* op_data = new OpData; + op_data->kernel_type = kTfLiteLSTMFullKernel; + context->AddTensors(context, /*tensors_to_add=*/7, + &op_data->scratch_tensor_index); + return op_data; } // Check that input tensor dimensions matches with each other. 
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TfLiteNode* node, int n_input, int n_output, int n_cell) { - auto* params = reinterpret_cast(node->builtin_data); + const auto* params = reinterpret_cast(node->builtin_data); // Making sure clipping parameters have valid values. // == 0 means no clipping @@ -92,29 +103,29 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TF_LITE_ENSURE(context, params->cell_clip >= 0); TF_LITE_ENSURE(context, params->proj_clip >= 0); - TfLiteTensor* input_to_input_weights = + const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); - if (input_to_input_weights) { + if (input_to_input_weights != nullptr) { TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell); TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); } - TfLiteTensor* input_to_forget_weights = + const TfLiteTensor* input_to_forget_weights = GetInput(context, node, kInputToForgetWeightsTensor); TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell); TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input); - TfLiteTensor* input_to_cell_weights = + const TfLiteTensor* input_to_cell_weights = GetInput(context, node, kInputToCellWeightsTensor); TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell); TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input); - TfLiteTensor* recurrent_to_input_weights = + const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); - if (recurrent_to_input_weights) { + if (recurrent_to_input_weights != nullptr) { TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2); 
TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0], n_cell); @@ -122,7 +133,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, n_output); } - TfLiteTensor* recurrent_to_forget_weights = + const TfLiteTensor* recurrent_to_forget_weights = GetInput(context, node, kRecurrentToForgetWeightsTensor); TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0], @@ -130,7 +141,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1], n_output); - TfLiteTensor* recurrent_to_cell_weights = + const TfLiteTensor* recurrent_to_cell_weights = GetInput(context, node, kRecurrentToCellWeightsTensor); TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell); @@ -146,21 +157,21 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, (recurrent_to_input_weights == nullptr)); TF_LITE_ENSURE(context, cifg_weights_all_or_none == true); - TfLiteTensor* cell_to_input_weights = + const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); if (cell_to_input_weights) { TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1); TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell); } - TfLiteTensor* cell_to_forget_weights = + const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); if (cell_to_forget_weights) { TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1); TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell); } - TfLiteTensor* cell_to_output_weights = + const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); if (cell_to_output_weights) { TF_LITE_ENSURE_EQ(context, 
cell_to_output_weights->dims->size, 1); @@ -179,7 +190,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TF_LITE_ENSURE(context, peephole_weights_all_or_none == true); // Make sure the input gate bias is present only when not a CIFG-LSTM. - TfLiteTensor* input_gate_bias = + const TfLiteTensor* input_gate_bias = GetOptionalInputTensor(context, node, kInputGateBiasTensor); if (use_cifg) { TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); @@ -188,31 +199,31 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); } - TfLiteTensor* forget_gate_bias = + const TfLiteTensor* forget_gate_bias = GetInput(context, node, kForgetGateBiasTensor); TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell); - TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell); - TfLiteTensor* output_gate_bias = + const TfLiteTensor* output_gate_bias = GetInput(context, node, kOutputGateBiasTensor); TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell); - TfLiteTensor* projection_weights = + const TfLiteTensor* projection_weights = GetOptionalInputTensor(context, node, kProjectionWeightsTensor); - if (projection_weights) { + if (projection_weights != nullptr) { TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output); TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell); } - TfLiteTensor* projection_bias = + const TfLiteTensor* projection_bias = GetOptionalInputTensor(context, node, kProjectionBiasTensor); - if (projection_bias) { + if 
(projection_bias != nullptr) { TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output); } @@ -233,7 +244,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, // Allocate a temporary scratch tensor. Also check that the sizes of the input // tensors match each other. TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - int* scratch_tensor_index = reinterpret_cast(node->user_data); + OpData* op_data = reinterpret_cast(node->user_data); // Check we have all the inputs and outputs we need. TF_LITE_ENSURE_EQ(context, node->inputs->size, 18); @@ -241,18 +252,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Inferring batch size, number of outputs and number of cells from the // input tensors. - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE(context, input->dims->size > 1); const int n_batch = input->dims->data[0]; const int n_input = input->dims->data[1]; - TfLiteTensor* input_to_output_weights = + const TfLiteTensor* input_to_output_weights = GetInput(context, node, kInputToOutputWeightsTensor); const int n_cell = input_to_output_weights->dims->data[0]; TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input); - TfLiteTensor* recurrent_to_output_weights = + const TfLiteTensor* recurrent_to_output_weights = GetInput(context, node, kRecurrentToOutputWeightsTensor); TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0], @@ -286,86 +298,148 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, cell_state, cell_size)); - // Create a scratch buffer tensor. 
+ // Mark state tensors as persistent tensors. + output_state->allocation_type = kTfLiteArenaRwPersistent; + cell_state->allocation_type = kTfLiteArenaRwPersistent; + + // The weights are of consistent type, so it suffices to check one. + // TODO(mirkov): create a utility/macro for this check, so all Ops can use it. + const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 && + input->type == kTfLiteFloat32); + TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(1); - node->temporaries->data[0] = *scratch_tensor_index; + if (is_hybrid_op) { + node->temporaries = TfLiteIntArrayCreate(7); + } else { + node->temporaries = TfLiteIntArrayCreate(1); + } + node->temporaries->data[0] = op_data->scratch_tensor_index; + + // Create a scratch buffer tensor. TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); scratch_buffer->type = input->type; scratch_buffer->allocation_type = kTfLiteArenaRw; - // Mark state tensors as persistent tensors. 
- output_state->allocation_type = kTfLiteArenaRwPersistent; - cell_state->allocation_type = kTfLiteArenaRwPersistent; - - TfLiteTensor* input_to_input_weights = + const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); const bool use_cifg = (input_to_input_weights == nullptr); + TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); + scratch_buffer_size->data[0] = n_batch; if (use_cifg) { - TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); - scratch_buffer_size->data[0] = n_batch; // Reserving space for Cell, Forget, Output gates scratch_buffer_size->data[1] = n_cell * 3; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, - scratch_buffer_size)); } else { - TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); - scratch_buffer_size->data[0] = n_batch; // Reserving space for Input, Cell, Forget, Output gates scratch_buffer_size->data[1] = n_cell * 4; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, - scratch_buffer_size)); + } + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, + scratch_buffer_size)); + + if (is_hybrid_op) { + // Allocate temporary tensors to store quantized values of input, + // output_state and cell_state tensors. 
+ node->temporaries->data[1] = op_data->scratch_tensor_index + 1; + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + input_quantized->type = kTfLiteUInt8; + input_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { + TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, + input_quantized_size)); + } + node->temporaries->data[2] = op_data->scratch_tensor_index + 2; + TfLiteTensor* output_state_quantized = + GetTemporary(context, node, /*index=*/2); + output_state_quantized->type = kTfLiteUInt8; + output_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(output_state_quantized->dims, + output_state->dims)) { + TfLiteIntArray* output_state_quantized_size = + TfLiteIntArrayCopy(output_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output_state_quantized, + output_state_quantized_size)); + } + node->temporaries->data[3] = op_data->scratch_tensor_index + 3; + TfLiteTensor* cell_state_quantized = + GetTemporary(context, node, /*index=*/3); + cell_state_quantized->type = kTfLiteUInt8; + cell_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) { + TfLiteIntArray* cell_state_quantized_size = + TfLiteIntArrayCopy(cell_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, cell_state_quantized, + cell_state_quantized_size)); + } + + // Allocate temporary tensors to store scaling factors and product scaling + // factors. The latter is a convenience storage which allows to quantize + // a vector once (which produces the scaling factors) and multiply it with + // different matrices (which requires multiplying the scaling factors with + // the scaling factor of the matrix). 
+ node->temporaries->data[4] = op_data->scratch_tensor_index + 4; + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } + node->temporaries->data[5] = op_data->scratch_tensor_index + 5; + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, /*index=*/5); + prod_scaling_factors->type = kTfLiteFloat32; + prod_scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1); + prod_scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(prod_scaling_factors->dims, + prod_scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, prod_scaling_factors, + prod_scaling_factors_size)); + } + + // Allocate a temporary tensor to store the recovered cell weights. Since + // this is used for diagonal matrices, only need to store n_cell values. + node->temporaries->data[6] = op_data->scratch_tensor_index + 6; + TfLiteTensor* recovered_cell_weights = + GetTemporary(context, node, /*index=*/6); + recovered_cell_weights->type = kTfLiteFloat32; + recovered_cell_weights->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1); + recovered_cell_weights_size->data[0] = n_cell; + if (!TfLiteIntArrayEqual(recovered_cell_weights->dims, + recovered_cell_weights_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, recovered_cell_weights, + recovered_cell_weights_size)); + } } return kTfLiteOk; } // The LSTM Op engine. 
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); - TfLiteTensor* input = GetInput(context, node, kInputTensor); - - TfLiteTensor* input_to_input_weights = - GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); - TfLiteTensor* input_to_forget_weights = - GetInput(context, node, kInputToForgetWeightsTensor); - TfLiteTensor* input_to_cell_weights = - GetInput(context, node, kInputToCellWeightsTensor); - TfLiteTensor* input_to_output_weights = - GetInput(context, node, kInputToOutputWeightsTensor); - - TfLiteTensor* recurrent_to_input_weights = - GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); - TfLiteTensor* recurrent_to_forget_weights = - GetInput(context, node, kRecurrentToForgetWeightsTensor); - TfLiteTensor* recurrent_to_cell_weights = - GetInput(context, node, kRecurrentToCellWeightsTensor); - TfLiteTensor* recurrent_to_output_weights = - GetInput(context, node, kRecurrentToOutputWeightsTensor); - - TfLiteTensor* cell_to_input_weights = - GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); - TfLiteTensor* cell_to_forget_weights = - GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); - TfLiteTensor* cell_to_output_weights = - GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); - - TfLiteTensor* input_gate_bias = - GetOptionalInputTensor(context, node, kInputGateBiasTensor); - TfLiteTensor* forget_gate_bias = - GetInput(context, node, kForgetGateBiasTensor); - TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); - TfLiteTensor* output_gate_bias = - GetInput(context, node, kOutputGateBiasTensor); - - TfLiteTensor* projection_weights = - GetOptionalInputTensor(context, node, kProjectionWeightsTensor); - TfLiteTensor* projection_bias = - GetOptionalInputTensor(context, node, kProjectionBiasTensor); - - TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); - TfLiteTensor* cell_state 
= GetOutput(context, node, kCellStateTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - +TfLiteStatus EvalFloat( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, + TfLiteTensor* output_state, TfLiteTensor* cell_state, + TfLiteTensor* output) { const int n_batch = input->dims->data[0]; const int n_input = input->dims->data[1]; // n_cell and n_output will be the same size when there is no projection. @@ -377,9 +451,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const bool use_cifg = (input_to_input_weights == nullptr); const bool use_peephole = (cell_to_output_weights != nullptr); - // Index the scratch buffers pointers to the global scratch buffer. 
- TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); - float* input_gate_scratch = nullptr; float* cell_scratch = nullptr; float* forget_gate_scratch = nullptr; @@ -447,6 +518,421 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus EvalHybrid( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, + TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors, + TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized, + TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, + TfLiteTensor* output_state, TfLiteTensor* cell_state, + TfLiteTensor* output) { + const int n_batch = input->dims->data[0]; + const int n_input = input->dims->data[1]; + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existence of only one to get the condition. 
+ const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + + float* input_gate_scratch = nullptr; + float* cell_scratch = nullptr; + float* forget_gate_scratch = nullptr; + float* output_gate_scratch = nullptr; + if (use_cifg) { + cell_scratch = scratch_buffer->data.f; + forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + } else { + input_gate_scratch = scratch_buffer->data.f; + cell_scratch = scratch_buffer->data.f + n_cell * n_batch; + forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; + } + + // Check optional tensors, the respective pointers can be null. + int8_t* input_to_input_weights_ptr = nullptr; + float input_to_input_weights_scale = 1.0f; + int8_t* recurrent_to_input_weights_ptr = nullptr; + float recurrent_to_input_weights_scale = 1.0f; + float* input_gate_bias_ptr = nullptr; + if (!use_cifg) { + input_to_input_weights_ptr = + reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8); + recurrent_to_input_weights_ptr = + reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8); + input_gate_bias_ptr = input_gate_bias->data.f; + input_to_input_weights_scale = input_to_input_weights->params.scale; + recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale; + } + + int8_t* cell_to_input_weights_ptr = nullptr; + int8_t* cell_to_forget_weights_ptr = nullptr; + int8_t* cell_to_output_weights_ptr = nullptr; + float cell_to_input_weights_scale = 1.0f; + float cell_to_forget_weights_scale = 1.0f; + float cell_to_output_weights_scale = 1.0f; + if (use_peephole) { + if (!use_cifg) { + cell_to_input_weights_ptr = + reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8); + cell_to_input_weights_scale = cell_to_input_weights->params.scale; + } + cell_to_forget_weights_ptr = + 
reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8); + cell_to_output_weights_ptr = + reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8); + cell_to_forget_weights_scale = cell_to_forget_weights->params.scale; + cell_to_output_weights_scale = cell_to_output_weights->params.scale; + } + + const int8_t* projection_weights_ptr = + (projection_weights == nullptr) + ? nullptr + : reinterpret_cast<int8_t*>(projection_weights->data.uint8); + const float projection_weights_scale = + (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale; + const float* projection_bias_ptr = + (projection_bias == nullptr) ? nullptr : projection_bias->data.f; + + // Required tensors, pointers are non-null. + const float* input_ptr_batch = input->data.f; + const int8_t* input_to_forget_weights_ptr = + reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8); + const float input_to_forget_weights_scale = + input_to_forget_weights->params.scale; + const int8_t* input_to_cell_weights_ptr = + reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8); + const float input_to_cell_weights_scale = input_to_cell_weights->params.scale; + const int8_t* input_to_output_weights_ptr = + reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8); + const float input_to_output_weights_scale = + input_to_output_weights->params.scale; + const int8_t* recurrent_to_forget_weights_ptr = + reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8); + const float recurrent_to_forget_weights_scale = + recurrent_to_forget_weights->params.scale; + const int8_t* recurrent_to_cell_weights_ptr = + reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8); + const float recurrent_to_cell_weights_scale = + recurrent_to_cell_weights->params.scale; + const int8_t* recurrent_to_output_weights_ptr = + reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8); + const float recurrent_to_output_weights_scale = + recurrent_to_output_weights->params.scale; + const float* forget_gate_bias_ptr = forget_gate_bias->data.f; + const float*
cell_bias_ptr = cell_bias->data.f; + const float* output_gate_bias_ptr = output_gate_bias->data.f; + + float* output_state_ptr = output_state->data.f; + float* cell_state_ptr = cell_state->data.f; + float* output_ptr_batch = output->data.f; + + // Temporary storage for quantized values and scaling factors. + int8_t* quantized_input_ptr = + reinterpret_cast<int8_t*>(input_quantized->data.uint8); + int8_t* quantized_output_state_ptr = + reinterpret_cast<int8_t*>(output_state_quantized->data.uint8); + int8_t* quantized_cell_state_ptr = + reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8); + float* scaling_factors_ptr = scaling_factors->data.f; + float* prod_scaling_factors_ptr = prod_scaling_factors->data.f; + float* recovered_cell_weights_ptr = recovered_cell_weights->data.f; + + kernel_utils::LstmStep( + input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale, + input_to_forget_weights_ptr, input_to_forget_weights_scale, + input_to_cell_weights_ptr, input_to_cell_weights_scale, + input_to_output_weights_ptr, input_to_output_weights_scale, + recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale, + recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale, + recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale, + recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale, + cell_to_input_weights_ptr, cell_to_input_weights_scale, + cell_to_forget_weights_ptr, cell_to_forget_weights_scale, + cell_to_output_weights_ptr, cell_to_output_weights_scale, + input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, + output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale, + projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, + input_gate_scratch, forget_gate_scratch, cell_scratch, + output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr, + recovered_cell_weights_ptr, quantized_input_ptr, + quantized_output_state_ptr, quantized_cell_state_ptr, output_state_ptr, + cell_state_ptr, output_ptr_batch); + +
return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + + const TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + const TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + const TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + const TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + + const TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); + const TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + const TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + const TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + + const TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); + const TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); + const TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + + const TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, kInputGateBiasTensor); + const TfLiteTensor* forget_gate_bias = + GetInput(context, node, kForgetGateBiasTensor); + const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + const TfLiteTensor* output_gate_bias = + GetInput(context, node, kOutputGateBiasTensor); + + const TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + const TfLiteTensor* projection_bias = +
GetOptionalInputTensor(context, node, kProjectionBiasTensor); + + // Index the scratch buffers pointers to the global scratch buffer. + TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); + + TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); + TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // TODO(mirkov): add a check that weights are all uint8s or all floats. + switch (input_to_output_weights->type) { + case kTfLiteFloat32: { + return EvalFloat(input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, + cell_to_output_weights, input_gate_bias, + forget_gate_bias, cell_bias, output_gate_bias, + projection_weights, projection_bias, params, + scratch_buffer, output_state, cell_state, output); + } + case kTfLiteUInt8: { + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + TfLiteTensor* output_state_quantized = + GetTemporary(context, node, /*index=*/2); + TfLiteTensor* cell_state_quantized = + GetTemporary(context, node, /*index=*/3); + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4); + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, /*index=*/5); + TfLiteTensor* recovered_cell_weights = + GetTemporary(context, node, /*index=*/6); + return EvalHybrid( + input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, + input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, + projection_weights, projection_bias, 
params, scratch_buffer, + scaling_factors, prod_scaling_factors, recovered_cell_weights, + input_quantized, output_state_quantized, cell_state_quantized, + output_state, cell_state, output); + } + default: + context->ReportError(context, "Type %d is not currently supported.", + input_to_output_weights->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace full + +// For basic kernel (5-inputs). +namespace basic { + +enum InputTensor { + kInputData = 0, + kInputPrevActivation = 1, + kInputWeights = 2, + kInputBiases = 3, + kInputPrevState = 4, + kInputNum = 5, +}; + +enum OutputTensor { + kOutputActivation = 0, + kOutputState = 1, + kOutputConcatTemp = 2, + kOutputActivationTemp = 3, + kOutputNum = 4, +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* op_data = new OpData; + op_data->kernel_type = kTfLiteLSTMBasicKernel; + // `scratch_tensor_index` is unused in this kernel. + op_data->scratch_tensor_index = -1; + return op_data; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE(context, node->inputs->size == kInputNum); + TF_LITE_ENSURE(context, node->outputs->size == kOutputNum); + + // Only Float32 is supported currently. + // TODO(ycling): Implement quantize uint8 support. 
+ for (int index = 0; index < node->inputs->size; ++index) { + TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]]; + TF_LITE_ENSURE_EQ(context, tensor->type, kTfLiteFloat32); + } + + const TfLiteTensor* input = GetInput(context, node, kInputData); + const TfLiteTensor* prev_activation = + GetInput(context, node, kInputPrevActivation); + const TfLiteTensor* weights = GetInput(context, node, kInputWeights); + const TfLiteTensor* bias = GetInput(context, node, kInputBiases); + const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState); + + TF_LITE_ENSURE_EQ(context, input->dims->size, 2); + const int num_batches = input->dims->data[0]; + + TF_LITE_ENSURE_EQ(context, prev_activation->dims->size, 2); + TF_LITE_ENSURE_EQ(context, prev_activation->dims->data[0], num_batches); + + TF_LITE_ENSURE_EQ(context, weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, bias->dims->size, 1); + + TF_LITE_ENSURE_EQ(context, prev_state->dims->size, 2); + TF_LITE_ENSURE_EQ(context, prev_state->dims->data[0], num_batches); + + TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation); + TfLiteTensor* state_out = GetOutput(context, node, kOutputState); + TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp); + TfLiteTensor* activation_temp = + GetOutput(context, node, kOutputActivationTemp); + + TF_LITE_ENSURE_OK(context, context->ResizeTensor( + context, activation_out, + TfLiteIntArrayCopy(prev_activation->dims))); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, state_out, + TfLiteIntArrayCopy(prev_state->dims))); + TfLiteIntArray* concat_temp_size = TfLiteIntArrayCreate(2); + concat_temp_size->data[0] = num_batches; + concat_temp_size->data[1] = weights->dims->data[1]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, concat_temp, concat_temp_size)); + TfLiteIntArray* activation_temp_size = TfLiteIntArrayCreate(2); + activation_temp_size->data[0] = num_batches; + activation_temp_size->data[1] 
= weights->dims->data[0]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_temp, + activation_temp_size)); + + // Set the state tensors as persistent. + for (auto index : {kInputPrevActivation, kInputPrevState}) { + TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]]; + tensor->allocation_type = kTfLiteArenaRwPersistent; + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputData); + const TfLiteTensor* prev_activation = + GetInput(context, node, kInputPrevActivation); + const TfLiteTensor* weights = GetInput(context, node, kInputWeights); + const TfLiteTensor* bias = GetInput(context, node, kInputBiases); + const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState); + + TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation); + TfLiteTensor* state_out = GetOutput(context, node, kOutputState); + TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp); + TfLiteTensor* activation_temp = + GetOutput(context, node, kOutputActivationTemp); + + optimized_ops::LstmCell( + // Inputs. + GetTensorData<float>(input), GetTensorDims(input), + GetTensorData<float>(prev_activation), GetTensorDims(prev_activation), + GetTensorData<float>(weights), GetTensorDims(weights), + GetTensorData<float>(bias), GetTensorDims(bias), + GetTensorData<float>(prev_state), GetTensorDims(prev_state), + // Outputs. + GetTensorData<float>(state_out), GetTensorDims(state_out), + GetTensorData<float>(activation_out), GetTensorDims(activation_out), + GetTensorData<float>(concat_temp), GetTensorDims(concat_temp), + GetTensorData<float>(activation_temp), GetTensorDims(activation_temp)); + + // TODO(ycling): Investigate if this copy can be avoided with the 5-inputs + // LSTM kernel.
+ memcpy(prev_activation->data.raw, activation_out->data.raw, + activation_out->bytes); + memcpy(prev_state->data.raw, state_out->data.raw, state_out->bytes); + + return kTfLiteOk; +} + +} // namespace basic + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + const auto* params = reinterpret_cast<const TfLiteLSTMParams*>(buffer); + switch (params->kernel_type) { + case kTfLiteLSTMFullKernel: + return full::Init(context, buffer, length); + case kTfLiteLSTMBasicKernel: + return basic::Init(context, buffer, length); + } +} +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast<OpData*>(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const auto* op_data = reinterpret_cast<OpData*>(node->user_data); + switch (op_data->kernel_type) { + case kTfLiteLSTMFullKernel: + return full::Prepare(context, node); + case kTfLiteLSTMBasicKernel: + return basic::Prepare(context, node); + } +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const auto* op_data = reinterpret_cast<OpData*>(node->user_data); + switch (op_data->kernel_type) { + case kTfLiteLSTMFullKernel: + return full::Eval(context, node); + case kTfLiteLSTMBasicKernel: + return basic::Eval(context, node); + } +} + } // namespace lstm TfLiteRegistration* Register_LSTM() { diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc index d81220d8d30793616444c03e8647b0877a39a4d9..6da29a4a923f16f7b5ad382f51cfd820783504cd 100644 --- a/tensorflow/contrib/lite/kernels/lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ // Unit test for TFLite LSTM op.
-#include #include #include @@ -35,7 +34,8 @@ class LSTMOpModel : public SingleOpModel { LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg, bool use_peephole, bool use_projection_weights, bool use_projection_bias, float cell_clip, float proj_clip, - const std::vector>& input_shapes) + const std::vector>& input_shapes, + const TensorType& weight_type = TensorType_FLOAT32) : n_batch_(n_batch), n_input_(n_input), n_cell_(n_cell), @@ -45,31 +45,31 @@ class LSTMOpModel : public SingleOpModel { if (use_cifg) { input_to_input_weights_ = AddNullInput(); } else { - input_to_input_weights_ = AddInput(TensorType_FLOAT32); + input_to_input_weights_ = AddInput(weight_type); } - input_to_forget_weights_ = AddInput(TensorType_FLOAT32); - input_to_cell_weights_ = AddInput(TensorType_FLOAT32); - input_to_output_weights_ = AddInput(TensorType_FLOAT32); + input_to_forget_weights_ = AddInput(weight_type); + input_to_cell_weights_ = AddInput(weight_type); + input_to_output_weights_ = AddInput(weight_type); if (use_cifg) { recurrent_to_input_weights_ = AddNullInput(); } else { - recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_input_weights_ = AddInput(weight_type); } - recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32); - recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32); - recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_forget_weights_ = AddInput(weight_type); + recurrent_to_cell_weights_ = AddInput(weight_type); + recurrent_to_output_weights_ = AddInput(weight_type); if (use_peephole) { if (use_cifg) { cell_to_input_weights_ = AddNullInput(); } else { - cell_to_input_weights_ = AddInput(TensorType_FLOAT32); + cell_to_input_weights_ = AddInput(weight_type); } - cell_to_forget_weights_ = AddInput(TensorType_FLOAT32); - cell_to_output_weights_ = AddInput(TensorType_FLOAT32); + cell_to_forget_weights_ = AddInput(weight_type); + cell_to_output_weights_ = AddInput(weight_type); } else { 
cell_to_input_weights_ = AddNullInput(); cell_to_forget_weights_ = AddNullInput(); @@ -86,7 +86,7 @@ class LSTMOpModel : public SingleOpModel { output_gate_bias_ = AddInput(TensorType_FLOAT32); if (use_projection_weights) { - projection_weights_ = AddInput(TensorType_FLOAT32); + projection_weights_ = AddInput(weight_type); if (use_projection_bias) { projection_bias_ = AddInput(TensorType_FLOAT32); } else { @@ -192,8 +192,9 @@ class LSTMOpModel : public SingleOpModel { zero_buffer.get() + zero_buffer_size); } - void SetInput(int offset, float* begin, float* end) { - PopulateTensor(input_, offset, begin, end); + void SetInput(int offset, const float* begin, const float* end) { + PopulateTensor(input_, offset, const_cast(begin), + const_cast(end)); } std::vector GetOutput() { return ExtractVector(output_); } @@ -203,7 +204,7 @@ class LSTMOpModel : public SingleOpModel { int num_cells() { return n_cell_; } int num_batches() { return n_batch_; } - private: + protected: int input_; int input_to_input_weights_; int input_to_forget_weights_; @@ -237,7 +238,182 @@ class LSTMOpModel : public SingleOpModel { int n_output_; }; -TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { +class HybridLSTMOpModel : public LSTMOpModel { + public: + HybridLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, + bool use_cifg, bool use_peephole, + bool use_projection_weights, bool use_projection_bias, + float cell_clip, float proj_clip, + const std::vector>& input_shapes) + : LSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg, use_peephole, + use_projection_weights, use_projection_bias, cell_clip, + proj_clip, input_shapes, TensorType_UINT8) {} + + void SetInputToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + 
SymmetricQuantizeAndPopulate(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(cell_to_output_weights_, f); + } + + void SetProjectionWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(projection_weights_, f); + } +}; + +class BaseLstmTest : public ::testing::Test { + protected: + // Weights of the LSTM model. Some are optional.
+ std::initializer_list<float> input_to_input_weights_; + std::initializer_list<float> input_to_cell_weights_; + std::initializer_list<float> input_to_forget_weights_; + std::initializer_list<float> input_to_output_weights_; + std::initializer_list<float> input_gate_bias_; + std::initializer_list<float> cell_gate_bias_; + std::initializer_list<float> forget_gate_bias_; + std::initializer_list<float> output_gate_bias_; + std::initializer_list<float> recurrent_to_input_weights_; + std::initializer_list<float> recurrent_to_cell_weights_; + std::initializer_list<float> recurrent_to_forget_weights_; + std::initializer_list<float> recurrent_to_output_weights_; + std::initializer_list<float> cell_to_input_weights_; + std::initializer_list<float> cell_to_forget_weights_; + std::initializer_list<float> cell_to_output_weights_; + std::initializer_list<float> projection_weights_; + + // LSTM input is stored as num_batch x num_inputs vector. + std::vector<std::vector<float>> lstm_input_; + // LSTM output is stored as num_batch x num_outputs vector. + std::vector<std::vector<float>> lstm_golden_output_; + + // Compares output up to tolerance to the result of the lstm given the input.
+ void VerifyGoldens(const std::vector>& input, + const std::vector>& output, + LSTMOpModel* lstm, float tolerance = 1e-5) { + const int num_batches = input.size(); + EXPECT_GT(num_batches, 0); + const int num_inputs = lstm->num_inputs(); + EXPECT_GT(num_inputs, 0); + const int input_sequence_size = input[0].size() / num_inputs; + EXPECT_GT(input_sequence_size, 0); + for (int i = 0; i < input_sequence_size; ++i) { + for (int b = 0; b < num_batches; ++b) { + const float* batch_start = input[b].data() + i * num_inputs; + const float* batch_end = batch_start + num_inputs; + + lstm->SetInput(b * lstm->num_inputs(), batch_start, batch_end); + } + + lstm->Invoke(); + + const int num_outputs = lstm->num_outputs(); + std::vector expected; + for (int b = 0; b < num_batches; ++b) { + const float* golden_start_batch = output[b].data() + i * num_outputs; + const float* golden_end_batch = golden_start_batch + num_outputs; + expected.insert(expected.end(), golden_start_batch, golden_end_batch); + } + EXPECT_THAT(lstm->GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + for (int i = 0; i < num_outputs; ++i) { + std::cout << lstm->GetOutput()[i] << ", "; + } + std::cout << std::endl; + for (int i = 0; i < num_outputs; ++i) { + std::cout << expected[i] << ", "; + } + std::cout << std::endl; + } + } +}; + +class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}; + input_to_cell_weights_ = {-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, -0.29909778}; + input_to_forget_weights_ = {0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}; + input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, + -0.1556896, 0.19487578}; + input_gate_bias_ = {0., 0., 0., 
0.}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_input_weights_ = { + -0.0063535, -0.2042388, 0.31454784, -0.35746509, + 0.28902304, 0.08183324, -0.16555229, 0.02286911, + -0.13566875, 0.03034258, 0.48091322, -0.12528998, + 0.24077177, -0.51332325, -0.33502164, 0.10629296}; + + recurrent_to_cell_weights_ = { + -0.3407414, 0.24443203, -0.2078532, 0.26320225, + 0.05695659, -0.00123841, -0.4744786, -0.35869038, + -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}; + + recurrent_to_forget_weights_ = { + -0.48684245, -0.06655136, 0.42224967, 0.2112639, + 0.27654213, 0.20864892, -0.07646349, 0.45877004, + 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}; + + recurrent_to_output_weights_ = { + 0.43385774, -0.17194885, 0.2718237, 0.09215671, + 0.24107647, -0.39835793, 0.18212086, 0.01301402, + 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765, + -0.03716109, 0.12507336, 0.41193449, -0.20860538, + -0.15053082, 0.09120187, 0.24278517, -0.12222792}}; + } +}; + +TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. 
@@ -257,10 +433,10 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { {n_cell, n_input}, // input_to_cell_weight tensor {n_cell, n_input}, // input_to_output_weight tensor - {n_cell, n_output}, // recurrent_to_input_weight tensor - {n_cell, n_output}, // recurrent_to_forget_weight tensor - {n_cell, n_output}, // recurrent_to_cell_weight tensor - {n_cell, n_output}, // recurrent_to_output_weight tensor + {n_cell, n_output}, // recurrent_to_input_weight_tensor + {n_cell, n_output}, // recurrent_to_forget_weight_tensor + {n_cell, n_output}, // recurrent_to_cell_weight_tensor + {n_cell, n_output}, // recurrent_to_output_weight_tensor {0}, // cell_to_input_weight tensor {0}, // cell_to_forget_weight tensor @@ -275,79 +451,137 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, - -0.34550029, 0.04266912, -0.15680569, - -0.34856534, 0.43890524}); - - lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, - -0.20583314, 0.44344562, 0.22077113, - -0.29909778}); - - lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, - -0.31343272, -0.40032279, 0.44781327, - 0.01387155, -0.35593212}); - - lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, - 0.40525138, 0.44272184, 0.03897077, -0.1556896, - 0.19487578}); + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetInputGateBias({0., 0., 0., 0.}); + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - lstm.SetCellBias({0., 0., 0., 0.}); + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + 
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - lstm.SetForgetGateBias({1., 1., 1., 1.}); - - lstm.SetOutputGateBias({0., 0., 0., 0.}); - - lstm.SetRecurrentToInputWeights( - {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, - -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, - -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); - - lstm.SetRecurrentToCellWeights( - {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, - -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, - -0.46367589, 0.26016325, -0.03894562, -0.16368064}); + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); - lstm.SetRecurrentToForgetWeights( - {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, - -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, - 0.28053468, 0.01560611, -0.20127171, -0.01140004}); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetRecurrentToOutputWeights( - {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, - 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, - -0.51818722, -0.15390486, 0.0468148, 0.39922136}); +TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. 
+ const int n_cell = 4; + const int n_output = 4; - static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; - static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126, - -0.15358765, -0.03716109, 0.12507336, - 0.41193449, -0.20860538, -0.15053082, - 0.09120187, 0.24278517, -0.12222792}; + HybridLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + 
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - const int input_sequence_size = - sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); - for (int i = 0; i < input_sequence_size; i++) { - float* batch0_start = lstm_input + i * lstm.num_inputs(); - float* batch0_end = batch0_start + lstm.num_inputs(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, + /*tolerance=*/0.0157651); +} - lstm.SetInput(0, batch0_start, batch0_end); +class CifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726, + 0.05100781, 0.04717243, 0.48944736, + -0.38535351, -0.17212132}; - lstm.Invoke(); + input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, + 0.24407166, 0.33826375}; - float* golden_start = lstm_golden_output + i * lstm.num_outputs(); - float* golden_end = golden_start + lstm.num_outputs(); - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_cell_weights_ = { + 0.54066205, -0.32668582, -0.43562764, -0.56094903, + 0.42957711, 0.01841056, -0.32764608, -0.33027974, + -0.10826075, 0.20675004, 0.19069612, -0.03026325, + -0.54532051, 0.33003211, 0.44901288, 0.21193194}; + + recurrent_to_forget_weights_ = { + -0.13832897, -0.0515101, -0.2359007, -0.16661474, + -0.14340827, 0.36986142, 0.23414481, 0.55899, + 0.10798943, -0.41174671, 0.17751795, -0.34484994, + -0.35874045, -0.11352962, 0.27268326, 0.54058349}; + + recurrent_to_output_weights_ = { + 
0.41613156, 0.42610586, -0.16495961, -0.5663873, + 0.30579174, -0.05115908, -0.33941799, 0.23364776, + 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}; + + cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408, + 0.31544167}; + cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703, + -0.77109635}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646, + -0.42312205, -0.01218222, 0.24201041, -0.08124574, + -0.358325, -0.04621704, 0.21641694, -0.06471302}}; } -} +}; -TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { +TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -385,74 +619,689 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, - 0.04717243, 0.48944736, -0.38535351, - -0.17212132}); - - lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, - -0.3633365, -0.22755712, 0.28253698, 0.24407166, - 0.33826375}); - - lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, - -0.09426838, -0.44257352, 0.54939759, - 0.01533556, 0.42751634}); - - lstm.SetCellBias({0., 0., 0., 0.}); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetForgetGateBias({1., 1., 1., 1.}); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - lstm.SetOutputGateBias({0., 0., 0., 0.}); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + 
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - lstm.SetRecurrentToCellWeights( - {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, - 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, - 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, - 0.21193194}); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); - lstm.SetRecurrentToForgetWeights( - {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, - 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, - -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); - lstm.SetRecurrentToOutputWeights( - {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, - -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, - 0.50248802, 0.26114327, -0.43736315, 0.33149987}); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetCellToForgetWeights( - {0.47485286, -0.51955009, -0.24458408, 0.31544167}); - lstm.SetCellToOutputWeights( - {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); +TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. 
+ const int n_cell = 4; + const int n_output = 4; - static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; - static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585, - -0.05163646, -0.42312205, -0.01218222, - 0.24201041, -0.08124574, -0.358325, - -0.04621704, 0.21641694, -0.06471302}; + HybridLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); // Resetting cell_state and output_state 
lstm.ResetCellState(); lstm.ResetOutputState(); - const int input_sequence_size = - sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); - for (int i = 0; i < input_sequence_size; i++) { - float* batch0_start = lstm_input + i * lstm.num_inputs(); - float* batch0_end = batch0_start + lstm.num_inputs(); - - lstm.SetInput(0, batch0_start, batch0_end); - - lstm.Invoke(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); +} - float* golden_start = lstm_golden_output + i * lstm.num_outputs(); - float* golden_end = golden_start + lstm.num_outputs(); - std::vector<float> expected; - expected.insert(expected.end(), golden_start, golden_end); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); +class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = { + 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, + 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, + -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, + -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, + -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, + -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, + -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, + 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, + 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, + 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, + -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, + 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, + -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, + -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, + -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, + 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, + -0.04544234, -0.0497073, -0.07135631, -0.048929106,
-0.004042012, + -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, + -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, + -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}; + + input_to_forget_weights_ = { + -0.0018401089, -0.004852237, 0.03698424, 0.014181704, + 0.028273236, -0.016726194, -0.05249759, -0.10204261, + 0.00861066, -0.040979505, -0.009899187, 0.01923892, + -0.028177269, -0.08535103, -0.14585495, 0.10662567, + -0.01909731, -0.017883534, -0.0047269356, -0.045103323, + 0.0030784295, 0.076784775, 0.07463696, 0.094531395, + 0.0814421, -0.12257899, -0.033945758, -0.031303465, + 0.045630626, 0.06843887, -0.13492945, -0.012480007, + -0.0811829, -0.07224499, -0.09628791, 0.045100946, + 0.0012300825, 0.013964662, 0.099372394, 0.02543059, + 0.06958324, 0.034257296, 0.0482646, 0.06267997, + 0.052625068, 0.12784666, 0.07077897, 0.025725935, + 0.04165009, 0.07241905, 0.018668644, -0.037377294, + -0.06277783, -0.08833636, -0.040120605, -0.011405586, + -0.007808335, -0.010301386, -0.005102167, 0.027717464, + 0.05483423, 0.11449111, 0.11289652, 0.10939839, + 0.13396506, -0.08402166, -0.01901462, -0.044678304, + -0.07720565, 0.014350063, -0.11757958, -0.0652038, + -0.08185733, -0.076754324, -0.092614375, 0.10405491, + 0.052960336, 0.035755895, 0.035839386, -0.012540553, + 0.036881298, 0.02913376, 0.03420159, 0.05448447, + -0.054523353, 0.02582715, 0.02327355, -0.011857179, + -0.0011980024, -0.034641717, -0.026125094, -0.17582615, + -0.15923657, -0.27486774, -0.0006143371, 0.0001771948, + -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}; + + input_to_cell_weights_ = { + -0.04580283, -0.09549462, -0.032418985, -0.06454633, + -0.043528453, 0.043018587, -0.049152344, -0.12418144, + -0.078985475, -0.07596889, 0.019484362, -0.11434962, + -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, + -0.025034338, -0.0028890965, 0.048929527, 0.06235075, + 0.10665918, -0.032036792, -0.08505916, -0.10843358, + -0.13002433, 
-0.036816437, -0.02130134, -0.016518239, + 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, + -0.10652836, -0.1037554, -0.13056071, -0.03266643, + -0.033702414, -0.006473424, -0.04611692, 0.014419339, + -0.025174323, 0.0396852, 0.081777506, 0.06157468, + 0.10210095, -0.009658194, 0.046511717, 0.03603906, + 0.0069369148, 0.015960095, -0.06507666, 0.09551598, + 0.053568836, 0.06408714, 0.12835667, -0.008714329, + -0.20211966, -0.12093674, 0.029450472, 0.2849013, + -0.029227901, 0.1164364, -0.08560263, 0.09941786, + -0.036999565, -0.028842626, -0.0033637602, -0.017012902, + -0.09720865, -0.11193351, -0.029155117, -0.017936034, + -0.009768936, -0.04223324, -0.036159635, 0.06505112, + -0.021742892, -0.023377212, -0.07221364, -0.06430552, + 0.05453865, 0.091149814, 0.06387331, 0.007518393, + 0.055960953, 0.069779344, 0.046411168, 0.10509911, + 0.07463894, 0.0075130584, 0.012850982, 0.04555431, + 0.056955688, 0.06555285, 0.050801456, -0.009862683, + 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}; + + input_to_output_weights_ = { + -0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, + -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, + 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, + -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, + -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, + 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, + -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, + -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, + -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, + -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, + 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, + 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, + 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, + -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, + 
0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, + 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, + -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, + 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, + -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, + -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}; + + input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666, + 0.053110216, -0.06928846, -0.13942584, -0.11816189, + 0.19483899, 0.03652339, -0.10250295, 0.036714908, + -0.18426876, 0.036065217, 0.21810818, 0.02383196, + -0.043370757, 0.08690144, -0.04444982, 0.00030581196}; + + forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696, + 0.11098921, 0.15378423, 0.09263801, 0.09790885, + 0.09508917, 0.061199076, 0.07665568, -0.015443159, + -0.03499149, 0.046190713, 0.08895977, 0.10899629, + 0.40694186, 0.06030037, 0.012413437, -0.06108739}; + + cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873, + -0.1483596, -0.10639995, -0.091433935, 0.058573797, + -0.06809782, -0.07889636, -0.043246906, -0.09829136, + -0.4279842, 0.034901652, 0.18797937, 0.0075234566, + 0.016178843, 0.1749513, 0.13975595, 0.92058027}; + + output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113, + 0.027195795, 0.35373217, -0.018957434, 0.008907322, + -0.0762701, 0.12018895, 0.04216877, 0.0022856654, + 0.040952638, 0.3147856, 0.08225149, -0.057416286, + -0.14995944, -0.008040261, 0.13208859, 0.029760877}; + + recurrent_to_input_weights_ = { + -0.001374326, -0.078856036, 0.10672688, 0.029162422, + -0.11585556, 0.02557986, -0.13446963, -0.035785314, + -0.01244275, 0.025961924, -0.02337298, -0.044228926, + -0.055839065, -0.046598054, -0.010546039, -0.06900766, + 0.027239809, 0.022582639, -0.013296484, -0.05459212, + 0.08981, -0.045407712, 0.08682226, -0.06867011, + -0.14390695, -0.02916037, 0.000996957, 0.091420636, + 0.14283475, -0.07390571, 
-0.06402044, 0.062524505, + -0.093129106, 0.04860203, -0.08364217, -0.08119002, + 0.009352075, 0.22920375, 0.0016303885, 0.11583097, + -0.13732095, 0.012405723, -0.07551853, 0.06343048, + 0.12162708, -0.031923793, -0.014335606, 0.01790974, + -0.10650317, -0.0724401, 0.08554849, -0.05727212, + 0.06556731, -0.042729504, -0.043227166, 0.011683251, + -0.013082158, -0.029302018, -0.010899579, -0.062036745, + -0.022509435, -0.00964907, -0.01567329, 0.04260106, + -0.07787477, -0.11576462, 0.017356863, 0.048673786, + -0.017577527, -0.05527947, -0.082487635, -0.040137455, + -0.10820036, -0.04666372, 0.022746278, -0.07851417, + 0.01068115, 0.032956902, 0.022433773, 0.0026891115, + 0.08944216, -0.0685835, 0.010513544, 0.07228705, + 0.02032331, -0.059686817, -0.0005566496, -0.086984694, + 0.040414046, -0.1380399, 0.094208956, -0.05722982, + 0.012092817, -0.04989123, -0.086576, -0.003399834, + -0.04696032, -0.045747425, 0.10091314, 0.048676282, + -0.029037097, 0.031399418, -0.0040285117, 0.047237843, + 0.09504992, 0.041799378, -0.049185462, -0.031518843, + -0.10516937, 0.026374253, 0.10058866, -0.0033195973, + -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, + -0.10167381, 0.042500053, -0.01447153, 0.06464186, + -0.017142897, 0.03312627, 0.009205989, 0.024138335, + -0.011337001, 0.035530265, -0.010912711, 0.0706555, + -0.005894094, 0.051841937, -0.1401738, -0.02351249, + 0.0365468, 0.07590991, 0.08838724, 0.021681072, + -0.10086113, 0.019608743, -0.06195883, 0.077335775, + 0.023646897, -0.095322326, 0.02233014, 0.09756986, + -0.048691444, -0.009579111, 0.07595467, 0.11480546, + -0.09801813, 0.019894179, 0.08502348, 0.004032281, + 0.037211012, 0.068537936, -0.048005626, -0.091520436, + -0.028379958, -0.01556313, 0.06554592, -0.045599163, + -0.01672207, -0.020169014, -0.011877351, -0.20212261, + 0.010889619, 0.0047078193, 0.038385306, 0.08540671, + -0.017140968, -0.0035865551, 0.016678626, 0.005633034, + 0.015963363, 0.00871737, 0.060130805, 0.028611384, + 0.10109069, 
-0.015060172, -0.07894427, 0.06401885, + 0.011584063, -0.024466386, 0.0047652307, -0.09041358, + 0.030737216, -0.0046374933, 0.14215417, -0.11823516, + 0.019899689, 0.006106124, -0.027092824, 0.0786356, + 0.05052217, -0.058925, -0.011402121, -0.024987547, + -0.0013661642, -0.06832946, -0.015667673, -0.1083353, + -0.00096863037, -0.06988685, -0.053350925, -0.027275559, + -0.033664223, -0.07978348, -0.025200296, -0.017207067, + -0.058403496, -0.055697463, 0.005798788, 0.12965427, + -0.062582195, 0.0013350133, -0.10482091, 0.0379771, + 0.072521195, -0.0029455067, -0.13797039, -0.03628521, + 0.013806405, -0.017858358, -0.01008298, -0.07700066, + -0.017081132, 0.019358726, 0.0027079724, 0.004635139, + 0.062634714, -0.02338735, -0.039547626, -0.02050681, + 0.03385117, -0.083611414, 0.002862572, -0.09421313, + 0.058618143, -0.08598433, 0.00972939, 0.023867095, + -0.053934585, -0.023203006, 0.07452513, -0.048767887, + -0.07314807, -0.056307215, -0.10433547, -0.06440842, + 0.04328182, 0.04389765, -0.020006588, -0.09076438, + -0.11652589, -0.021705797, 0.03345259, -0.010329105, + -0.025767034, 0.013057034, -0.07316461, -0.10145612, + 0.06358255, 0.18531723, 0.07759293, 0.12006465, + 0.1305557, 0.058638252, -0.03393652, 0.09622831, + -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, + -0.005644518, 0.06857898, -0.12598175, -0.035084512, + 0.03156317, -0.12794146, -0.031963028, 0.04692781, + 0.030070418, 0.0071660685, -0.095516115, -0.004643372, + 0.040170413, -0.062104587, -0.0037324072, 0.0554317, + 0.08184801, -0.019164372, 0.06791302, 0.034257166, + -0.10307039, 0.021943003, 0.046745934, 0.0790918, + -0.0265588, -0.007824208, 0.042546265, -0.00977924, + -0.0002440307, -0.017384544, -0.017990116, 0.12252321, + -0.014512694, -0.08251313, 0.08861942, 0.13589665, + 0.026351685, 0.012641483, 0.07466548, 0.044301085, + -0.045414884, -0.051112458, 0.03444247, -0.08502782, + -0.04106223, -0.028126027, 0.028473156, 0.10467447}; + + recurrent_to_cell_weights_ = { + 
-0.037322544, 0.018592842, 0.0056175636, -0.06253426, + 0.055647098, -0.05713207, -0.05626563, 0.005559383, + 0.03375411, -0.025757805, -0.088049285, 0.06017052, + -0.06570978, 0.007384076, 0.035123326, -0.07920549, + 0.053676967, 0.044480428, -0.07663568, 0.0071805613, + 0.08089997, 0.05143358, 0.038261272, 0.03339287, + -0.027673481, 0.044746667, 0.028349208, 0.020090483, + -0.019443132, -0.030755889, -0.0040000007, 0.04465846, + -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, + -0.10893326, 0.076739706, -0.08509834, -0.027997585, + 0.037871376, 0.01449768, -0.09002357, -0.06111149, + -0.046195522, 0.0422062, -0.005683705, -0.1253618, + -0.012925729, -0.04890792, 0.06985068, 0.037654128, + 0.03398274, -0.004781977, 0.007032333, -0.031787455, + 0.010868644, -0.031489216, 0.09525667, 0.013939797, + 0.0058680447, 0.0167067, 0.02668468, -0.04797466, + -0.048885044, -0.12722108, 0.035304096, 0.06554885, + 0.00972396, -0.039238118, -0.05159735, -0.11329045, + 0.1613692, -0.03750952, 0.06529313, -0.071974665, + -0.11769596, 0.015524369, -0.0013754242, -0.12446318, + 0.02786344, -0.014179351, 0.005264273, 0.14376344, + 0.015983658, 0.03406988, -0.06939408, 0.040699873, + 0.02111075, 0.09669095, 0.041345075, -0.08316494, + -0.07684199, -0.045768797, 0.032298047, -0.041805092, + 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, + -0.024950314, 0.11574242, 0.04508852, -0.04335324, + 0.06760663, -0.027437469, 0.07216407, 0.06977076, + -0.05438599, 0.034033038, -0.028602652, 0.05346137, + 0.043184172, -0.037189785, 0.10420091, 0.00882477, + -0.054019816, -0.074273005, -0.030617684, -0.0028467078, + 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, + 0.04361412, -0.007001822, 0.09631092, -0.06702025, + -0.042049985, -0.035070654, -0.04103342, -0.10273396, + 0.0544271, 0.037184782, -0.13150354, -0.0058036847, + -0.008264958, 0.042035464, 0.05891794, 0.029673764, + 0.0063542654, 0.044788733, 0.054816857, 0.062257513, + -0.00093483756, 0.048938446, -0.004952862, 
-0.007730018, + -0.04043371, -0.017094059, 0.07229206, -0.023670016, + -0.052195564, -0.025616996, -0.01520939, 0.045104615, + -0.007376126, 0.003533447, 0.006570588, 0.056037236, + 0.12436656, 0.051817212, 0.028532185, -0.08686856, + 0.11868599, 0.07663395, -0.07323171, 0.03463402, + -0.050708205, -0.04458982, -0.11590894, 0.021273347, + 0.1251325, -0.15313013, -0.12224372, 0.17228661, + 0.023029093, 0.086124025, 0.006445803, -0.03496501, + 0.028332196, 0.04449512, -0.042436164, -0.026587414, + -0.006041347, -0.09292539, -0.05678812, 0.03897832, + 0.09465633, 0.008115513, -0.02171956, 0.08304309, + 0.071401566, 0.019622514, 0.032163795, -0.004167056, + 0.02295182, 0.030739572, 0.056506045, 0.004612461, + 0.06524936, 0.059999723, 0.046395954, -0.0045512207, + -0.1335546, -0.030136576, 0.11584653, -0.014678886, + 0.0020118146, -0.09688814, -0.0790206, 0.039770417, + -0.0329582, 0.07922767, 0.029322514, 0.026405897, + 0.04207835, -0.07073373, 0.063781224, 0.0859677, + -0.10925287, -0.07011058, 0.048005477, 0.03438226, + -0.09606514, -0.006669445, -0.043381985, 0.04240257, + -0.06955775, -0.06769346, 0.043903265, -0.026784198, + -0.017840602, 0.024307009, -0.040079936, -0.019946516, + 0.045318738, -0.12233574, 0.026170589, 0.0074471775, + 0.15978073, 0.10185836, 0.10298046, -0.015476589, + -0.039390966, -0.072174534, 0.0739445, -0.1211869, + -0.0347889, -0.07943156, 0.014809798, -0.12412325, + -0.0030663363, 0.039695457, 0.0647603, -0.08291318, + -0.018529687, -0.004423833, 0.0037507233, 0.084633216, + -0.01514876, -0.056505352, -0.012800942, -0.06994386, + 0.012962922, -0.031234352, 0.07029052, 0.016418684, + 0.03618972, 0.055686004, -0.08663945, -0.017404709, + -0.054761406, 0.029065743, 0.052404847, 0.020238016, + 0.0048197987, -0.0214882, 0.07078733, 0.013016777, + 0.06262858, 0.009184685, 0.020785125, -0.043904778, + -0.0270329, -0.03299152, -0.060088247, -0.015162964, + -0.001828936, 0.12642565, -0.056757294, 0.013586685, + 0.09232601, -0.035886683, 0.06000002, 
0.05229691, + -0.052580316, -0.082029596, -0.010794592, 0.012947712, + -0.036429964, -0.085508935, -0.13127148, -0.017744139, + 0.031502828, 0.036232427, -0.031581745, 0.023051167, + -0.05325106, -0.03421577, 0.028793324, -0.034633752, + -0.009881397, -0.043551125, -0.018609839, 0.0019097115, + -0.008799762, 0.056595087, 0.0022273948, 0.055752404}; + + recurrent_to_forget_weights_ = { + -0.057784554, -0.026057621, -0.068447545, -0.022581743, + 0.14811787, 0.10826372, 0.09471067, 0.03987225, + -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, + 0.08414449, -0.022036452, -0.00066928595, -0.09203576, + 0.032950465, -0.10985798, -0.023809856, 0.0021431844, + -0.02196096, -0.00326074, 0.00058621005, -0.074678116, + -0.06193199, 0.055729095, 0.03736828, 0.020123724, + 0.061878487, -0.04729229, 0.034919553, -0.07585433, + -0.04421272, -0.044019096, 0.085488975, 0.04058006, + -0.06890133, -0.030951202, -0.024628663, -0.07672815, + 0.034293607, 0.08556707, -0.05293577, -0.033561368, + -0.04899627, 0.0241671, 0.015736353, -0.095442444, + -0.029564252, 0.016493602, -0.035026584, 0.022337519, + -0.026871363, 0.004780428, 0.0077918363, -0.03601621, + 0.016435321, -0.03263031, -0.09543275, -0.047392778, + 0.013454138, 0.028934088, 0.01685226, -0.086110644, + -0.046250615, -0.01847454, 0.047608484, 0.07339695, + 0.034546845, -0.04881143, 0.009128804, -0.08802852, + 0.03761666, 0.008096139, -0.014454086, 0.014361001, + -0.023502491, -0.0011840804, -0.07607001, 0.001856849, + -0.06509276, -0.006021153, -0.08570962, -0.1451793, + 0.060212336, 0.055259194, 0.06974018, 0.049454916, + -0.027794661, -0.08077226, -0.016179763, 0.1169753, + 0.17213494, -0.0056326236, -0.053934924, -0.0124349, + -0.11520337, 0.05409887, 0.088759385, 0.0019655675, + 0.0042065294, 0.03881498, 0.019844765, 0.041858196, + -0.05695512, 0.047233116, 0.038937137, -0.06542224, + 0.014429736, -0.09719407, 0.13908425, -0.05379757, + 0.012321099, 0.082840554, -0.029899208, 0.044217527, + 0.059855383, 0.07711018, 
-0.045319796, 0.0948846, + -0.011724666, -0.0033288454, -0.033542685, -0.04764985, + -0.13873616, 0.040668588, 0.034832682, -0.015319203, + -0.018715994, 0.046002675, 0.0599172, -0.043107376, + 0.0294216, -0.002314414, -0.022424703, 0.0030315618, + 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, + 0.12375372, -0.0006038222, 0.029104086, 0.087442465, + 0.052958444, 0.07558703, 0.04817258, 0.044462286, + -0.015213451, -0.08783778, -0.0561384, -0.003008196, + 0.047060397, -0.002058388, 0.03429439, -0.018839769, + 0.024734668, 0.024614193, -0.042046934, 0.09597743, + -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, + -0.02558259, -0.022822596, -0.023273505, -0.02464396, + -0.10991725, -0.006240552, 0.0074488563, 0.024044557, + 0.04383914, -0.046476185, 0.028658995, 0.060410924, + 0.050786525, 0.009452605, -0.0073054377, -0.024810238, + 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, + 0.015898481, 0.021362653, -0.030262267, 0.016587038, + -0.011442813, 0.041154444, -0.007631438, -0.03423484, + -0.010977775, 0.036152758, 0.0066366293, 0.11915515, + 0.02318443, -0.041350313, 0.021485701, -0.10906167, + -0.028218046, -0.00954771, 0.020531068, -0.11995105, + -0.03672871, 0.024019798, 0.014255957, -0.05221243, + -0.00661567, -0.04630967, 0.033188973, 0.10107534, + -0.014027541, 0.030796422, -0.10270911, -0.035999842, + 0.15443139, 0.07684145, 0.036571592, -0.035900835, + -0.0034699554, 0.06209149, 0.015920248, -0.031122351, + -0.03858649, 0.01849943, 0.13872518, 0.01503974, + 0.069941424, -0.06948533, -0.0088794185, 0.061282158, + -0.047401894, 0.03100163, -0.041533746, -0.10430945, + 0.044574402, -0.01425562, -0.024290353, 0.034563623, + 0.05866852, 0.023947537, -0.09445152, 0.035450947, + 0.02247216, -0.0042998926, 0.061146557, -0.10250651, + 0.020881841, -0.06747029, 0.10062043, -0.0023941975, + 0.03532124, -0.016341697, 0.09685456, -0.016764693, + 0.051808182, 0.05875331, -0.04536488, 0.001626336, + -0.028892258, -0.01048663, -0.009793449, 
-0.017093895, + 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, + -0.001845119, -0.03551521, 0.0018358806, 0.05763657, + -0.01769146, 0.040995963, 0.02235177, -0.060430344, + 0.11475477, -0.023854522, 0.10071741, 0.0686208, + -0.014250481, 0.034261297, 0.047418304, 0.08562733, + -0.030519066, 0.0060542435, 0.014653856, -0.038836084, + 0.04096551, 0.032249358, -0.08355519, -0.026823482, + 0.056386515, -0.010401743, -0.028396193, 0.08507674, + 0.014410365, 0.020995233, 0.17040324, 0.11511526, + 0.02459721, 0.0066619175, 0.025853224, -0.023133837, + -0.081302024, 0.017264642, -0.009585969, 0.09491168, + -0.051313367, 0.054532815, -0.014298593, 0.10657464, + 0.007076659, 0.10964551, 0.0409152, 0.008275321, + -0.07283536, 0.07937492, 0.04192024, -0.1075027}; + + recurrent_to_output_weights_ = { + 0.025825322, -0.05813119, 0.09495884, -0.045984812, + -0.01255415, -0.0026479573, -0.08196161, -0.054914974, + -0.0046604523, -0.029587349, -0.044576716, -0.07480124, + -0.082868785, 0.023254942, 0.027502948, -0.0039728214, + -0.08683098, -0.08116779, -0.014675607, -0.037924774, + -0.023314456, -0.007401714, -0.09255757, 0.029460307, + -0.08829125, -0.005139627, -0.08989442, -0.0555066, + 0.13596267, -0.025062224, -0.048351806, -0.03850004, + 0.07266485, -0.022414139, 0.05940088, 0.075114764, + 0.09597592, -0.010211725, -0.0049794707, -0.011523867, + -0.025980417, 0.072999895, 0.11091378, -0.081685916, + 0.014416728, 0.043229222, 0.034178585, -0.07530371, + 0.035837382, -0.085607, -0.007721233, -0.03287832, + -0.043848954, -0.06404588, -0.06632928, -0.073643476, + 0.008214239, -0.045984086, 0.039764922, 0.03474462, + 0.060612556, -0.080590084, 0.049127717, 0.04151091, + -0.030063879, 0.008801774, -0.023021035, -0.019558564, + 0.05158114, -0.010947698, -0.011825728, 0.0075720972, + 0.0699727, -0.0039981045, 0.069350146, 0.08799282, + 0.016156472, 0.035502106, 0.11695009, 0.006217345, + 0.13392477, -0.037875112, 0.025745004, 0.08940699, + -0.00924166, 0.0046702605, 
-0.036598757, -0.08811812, + 0.10522024, -0.032441203, 0.008176899, -0.04454919, + 0.07058152, 0.0067963637, 0.039206743, 0.03259838, + 0.03725492, -0.09515802, 0.013326398, -0.052055415, + -0.025676316, 0.03198509, -0.015951829, -0.058556724, + 0.036879618, 0.043357447, 0.028362012, -0.05908629, + 0.0059240665, -0.04995891, -0.019187413, 0.0276265, + -0.01628143, 0.0025863599, 0.08800015, 0.035250366, + -0.022165963, -0.07328642, -0.009415526, -0.07455109, + 0.11690406, 0.0363299, 0.07411125, 0.042103454, + -0.009660886, 0.019076364, 0.018299393, -0.046004917, + 0.08891175, 0.0431396, -0.026327137, -0.051502608, + 0.08979574, -0.051670972, 0.04940282, -0.07491107, + -0.021240504, 0.022596184, -0.034280192, 0.060163025, + -0.058211457, -0.051837247, -0.01349775, -0.04639988, + -0.035936575, -0.011681591, 0.064818054, 0.0073146066, + -0.021745546, -0.043124277, -0.06471268, -0.07053354, + -0.029321948, -0.05330136, 0.016933719, -0.053782392, + 0.13747959, -0.1361751, -0.11569455, 0.0033329215, + 0.05693899, -0.053219706, 0.063698, 0.07977434, + -0.07924483, 0.06936997, 0.0034815092, -0.007305279, + -0.037325785, -0.07251102, -0.033633437, -0.08677009, + 0.091591336, -0.14165086, 0.021752775, 0.019683983, + 0.0011612234, -0.058154266, 0.049996935, 0.0288841, + -0.0024567875, -0.14345716, 0.010955264, -0.10234828, + 0.1183656, -0.0010731248, -0.023590032, -0.072285876, + -0.0724771, -0.026382286, -0.0014920527, 0.042667855, + 0.0018776858, 0.02986552, 0.009814309, 0.0733756, + 0.12289186, 0.018043943, -0.0458958, 0.049412545, + 0.033632483, 0.05495232, 0.036686596, -0.013781798, + -0.010036754, 0.02576849, -0.08307328, 0.010112348, + 0.042521734, -0.05869831, -0.071689695, 0.03876447, + -0.13275425, -0.0352966, -0.023077697, 0.10285965, + 0.084736146, 0.15568255, -0.00040734606, 0.027835453, + -0.10292561, -0.032401145, 0.10053256, -0.026142767, + -0.08271222, -0.0030240538, -0.016368777, 0.1070414, + 0.042672627, 0.013456989, -0.0437609, -0.022309763, + 0.11576483, 
0.04108048, 0.061026827, -0.0190714, + -0.0869359, 0.037901703, 0.0610107, 0.07202949, + 0.01675338, 0.086139716, -0.08795751, -0.014898893, + -0.023771819, -0.01965048, 0.007955471, -0.043740474, + 0.03346837, -0.10549954, 0.090567775, 0.042013682, + -0.03176985, 0.12569028, -0.02421228, -0.029526481, + 0.023851605, 0.031539805, 0.05292009, -0.02344001, + -0.07811758, -0.08834428, 0.10094801, 0.16594367, + -0.06861939, -0.021256343, -0.041093912, -0.06669611, + 0.035498552, 0.021757556, -0.09302526, -0.015403468, + -0.06614931, -0.051798206, -0.013874718, 0.03630673, + 0.010412845, -0.08077351, 0.046185967, 0.0035662893, + 0.03541868, -0.094149634, -0.034814864, 0.003128424, + -0.020674974, -0.03944324, -0.008110165, -0.11113267, + 0.08484226, 0.043586485, 0.040582247, 0.0968012, + -0.065249965, -0.028036479, 0.0050708856, 0.0017462453, + 0.0326779, 0.041296225, 0.09164146, -0.047743853, + -0.015952192, -0.034451712, 0.084197424, -0.05347844, + -0.11768019, 0.085926116, -0.08251791, -0.045081906, + 0.0948852, 0.068401024, 0.024856757, 0.06978981, + -0.057309967, -0.012775832, -0.0032452994, 0.01977615, + -0.041040014, -0.024264973, 0.063464895, 0.05431621, + }; + + cell_to_input_weights_ = { + 0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, + -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, + -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, + 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}; + + cell_to_forget_weights_ = { + -0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, + -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, + -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, + 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}; + + cell_to_output_weights_ = { + 0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, + -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, + -0.11940523, 0.007358328, 0.1890978, 0.4833202, 
-0.34441817, + 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}; + + projection_weights_ = { + -0.009802181, 0.09401916, 0.0717386, -0.13895074, + 0.09641832, 0.060420845, 0.08539281, 0.054285463, + 0.061395317, 0.034448683, -0.042991187, 0.019801661, + -0.16840284, -0.015726732, -0.23041931, -0.024478018, + -0.10959692, -0.013875541, 0.18600968, -0.061274476, + 0.0138165, -0.08160894, -0.07661644, 0.032372914, + 0.16169067, 0.22465782, -0.03993472, -0.004017731, + 0.08633481, -0.28869787, 0.08682067, 0.17240396, + 0.014975425, 0.056431185, 0.031037588, 0.16702051, + 0.0077946745, 0.15140012, 0.29405436, 0.120285, + -0.188994, -0.027265169, 0.043389652, -0.022061434, + 0.014777949, -0.20203483, 0.094781205, 0.19100232, + 0.13987629, -0.036132768, -0.06426278, -0.05108664, + 0.13221376, 0.009441198, -0.16715929, 0.15859416, + -0.040437475, 0.050779544, -0.022187516, 0.012166504, + 0.027685808, -0.07675938, -0.0055694645, -0.09444123, + 0.0046453946, 0.050794356, 0.10770313, -0.20790008, + -0.07149004, -0.11425117, 0.008225835, -0.035802525, + 0.14374903, 0.15262283, 0.048710253, 0.1847461, + -0.007487823, 0.11000021, -0.09542012, 0.22619456, + -0.029149994, 0.08527916, 0.009043713, 0.0042746216, + 0.016261552, 0.022461696, 0.12689082, -0.043589946, + -0.12035478, -0.08361797, -0.050666027, -0.1248618, + -0.1275799, -0.071875185, 0.07377272, 0.09944291, + -0.18897448, -0.1593054, -0.06526116, -0.040107165, + -0.004618631, -0.067624845, -0.007576253, 0.10727444, + 0.041546922, -0.20424393, 0.06907816, 0.050412357, + 0.00724631, 0.039827548, 0.12449835, 0.10747581, + 0.13708383, 0.09134148, -0.12617786, -0.06428341, + 0.09956831, 0.1208086, -0.14676677, -0.0727722, + 0.1126304, 0.010139365, 0.015571211, -0.038128063, + 0.022913318, -0.042050496, 0.16842307, -0.060597885, + 0.10531834, -0.06411776, -0.07451711, -0.03410368, + -0.13393489, 0.06534304, 0.003620307, 0.04490757, + 0.05970546, 0.05197996, 0.02839995, 0.10434969, + -0.013699693, -0.028353551, 
-0.07260381, 0.047201227, + -0.024575593, -0.036445823, 0.07155557, 0.009672501, + -0.02328883, 0.009533515, -0.03606021, -0.07421458, + -0.028082801, -0.2678904, -0.13221288, 0.18419984, + -0.13012612, -0.014588381, -0.035059117, -0.04824723, + 0.07830115, -0.056184657, 0.03277091, 0.025466874, + 0.14494097, -0.12522776, -0.098633975, -0.10766018, + -0.08317623, 0.08594209, 0.07749552, 0.039474737, + 0.1776665, -0.07409566, -0.0477268, 0.29323658, + 0.10801441, 0.1154011, 0.013952499, 0.10739139, + 0.10708251, -0.051456142, 0.0074137426, -0.10430189, + 0.10034707, 0.045594677, 0.0635285, -0.0715442, + -0.089667566, -0.10811871, 0.00026344223, 0.08298446, + -0.009525053, 0.006585689, -0.24567553, -0.09450807, + 0.09648481, 0.026996298, -0.06419476, -0.04752702, + -0.11063944, -0.23441927, -0.17608605, -0.052156363, + 0.067035615, 0.19271925, -0.0032889997, -0.043264326, + 0.09663576, -0.057112187, -0.10100678, 0.0628376, + 0.04447668, 0.017961001, -0.10094388, -0.10190601, + 0.18335468, 0.10494553, -0.052095775, -0.0026118709, + 0.10539724, -0.04383912, -0.042349473, 0.08438151, + -0.1947263, 0.02251204, 0.11216432, -0.10307853, + 0.17351969, -0.039091777, 0.08066188, -0.00561982, + 0.12633002, 0.11335965, -0.0088127935, -0.019777594, + 0.06864014, -0.059751723, 0.016233567, -0.06894641, + -0.28651384, -0.004228674, 0.019708522, -0.16305895, + -0.07468996, -0.0855457, 0.099339016, -0.07580735, + -0.13775392, 0.08434318, 0.08330512, -0.12131499, + 0.031935584, 0.09180414, -0.08876437, -0.08049874, + 0.008753825, 0.03498998, 0.030215185, 0.03907079, + 0.089751154, 0.029194152, -0.03337423, -0.019092513, + 0.04331237, 0.04299654, -0.036394123, -0.12915532, + 0.09793732, 0.07512415, -0.11319543, -0.032502122, + 0.15661901, 0.07671967, -0.005491124, -0.19379048, + -0.218606, 0.21448623, 0.017840758, 0.1416943, + -0.07051762, 0.19488361, 0.02664691, -0.18104725, + -0.09334311, 0.15026465, -0.15493552, -0.057762887, + -0.11604192, -0.262013, -0.01391798, 0.012185008, + 
0.11156489, -0.07483202, 0.06693364, -0.26151478, + 0.046425626, 0.036540434, -0.16435726, 0.17338543, + -0.21401681, -0.11385144, -0.08283257, -0.069031075, + 0.030635102, 0.010969227, 0.11109743, 0.010919218, + 0.027526086, 0.13519906, 0.01891392, -0.046839405, + -0.040167913, 0.017953383, -0.09700955, 0.0061885654, + -0.07000971, 0.026893595, -0.038844477, 0.14543656}; + + lstm_input_ = { + {// Batch0: 4 (input_sequence_size) * 5 (n_input) + 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, // step 0 + 0.596268, 0.998386, 0.568695, 0.864524, 0.571277, // step 1 + 0.073204, 0.296072, 0.743333, 0.069199, 0.045348, // step 2 + 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, // step 3 + + {// Batch1: 4 (input_sequence_size) * 5 (n_input) + 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, // step 0 + 0.642421, 0.524260, 0.134799, 0.003639, 0.162482, // step 1 + 0.640394, 0.930399, 0.050782, 0.432485, 0.988078, // step 2 + 0.082922, 0.563329, 0.865614, 0.333232, 0.259916} // step 3 + }; + + lstm_golden_output_ = { + {// Batch0: 4 (input_sequence_size) * 16 (n_output) + -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, + -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, + -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, + 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, + -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, + -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, + 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, + 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, + 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, + 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, + -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, + -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, + 0.0286833, 0.00824207, 0.0264887, 0.0305169}, + {// Batch1: 4 (input_sequence_size) * 16 (n_output) + -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, + -0.0186926, 0.0193662, 
-0.0115437, 0.00422612, -0.0345232, + 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, + 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, + -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, + -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, + 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, + 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, + 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, + 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, + -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, + -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, + 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; } -} +}; -TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { +TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 2; const int n_input = 5; const int n_cell = 20; @@ -489,588 +1338,98 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToInputWeights( - {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, - 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, - -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, - -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, - -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, - -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, - -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, - 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, - 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, - 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, - -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, - 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, - -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, - -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, - 
-0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, - 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, - -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, - -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, - -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, - -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}); - - lstm.SetInputToForgetWeights( - {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, - -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505, - -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495, - 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323, - 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421, - -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887, - -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, - 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059, - 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068, - 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905, - 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605, - -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464, - 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, - -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063, - -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375, - 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553, - 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353, - 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717, - -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, - 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}); - - lstm.SetInputToCellWeights( - {-0.04580283, -0.09549462, -0.032418985, -0.06454633, - -0.043528453, 0.043018587, -0.049152344, -0.12418144, - -0.078985475, -0.07596889, 0.019484362, -0.11434962, - 
-0.0074034138, -0.06314844, -0.092981495, 0.0062155537, - -0.025034338, -0.0028890965, 0.048929527, 0.06235075, - 0.10665918, -0.032036792, -0.08505916, -0.10843358, - -0.13002433, -0.036816437, -0.02130134, -0.016518239, - 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, - -0.10652836, -0.1037554, -0.13056071, -0.03266643, - -0.033702414, -0.006473424, -0.04611692, 0.014419339, - -0.025174323, 0.0396852, 0.081777506, 0.06157468, - 0.10210095, -0.009658194, 0.046511717, 0.03603906, - 0.0069369148, 0.015960095, -0.06507666, 0.09551598, - 0.053568836, 0.06408714, 0.12835667, -0.008714329, - -0.20211966, -0.12093674, 0.029450472, 0.2849013, - -0.029227901, 0.1164364, -0.08560263, 0.09941786, - -0.036999565, -0.028842626, -0.0033637602, -0.017012902, - -0.09720865, -0.11193351, -0.029155117, -0.017936034, - -0.009768936, -0.04223324, -0.036159635, 0.06505112, - -0.021742892, -0.023377212, -0.07221364, -0.06430552, - 0.05453865, 0.091149814, 0.06387331, 0.007518393, - 0.055960953, 0.069779344, 0.046411168, 0.10509911, - 0.07463894, 0.0075130584, 0.012850982, 0.04555431, - 0.056955688, 0.06555285, 0.050801456, -0.009862683, - 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}); - - lstm.SetInputToOutputWeights( - {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, - -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, - 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, - -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, - -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, - 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, - -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, - -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, - -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, - -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, - 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, - 0.04974699, 
0.014160473, 0.06973932, 0.04964942, 0.033364646, - 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, - -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, - 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, - 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, - -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, - 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, - -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, - -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}); - - lstm.SetInputGateBias( - {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216, - -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339, - -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818, - 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196}); - - lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696, - 0.11098921, 0.15378423, 0.09263801, 0.09790885, - 0.09508917, 0.061199076, 0.07665568, -0.015443159, - -0.03499149, 0.046190713, 0.08895977, 0.10899629, - 0.40694186, 0.06030037, 0.012413437, -0.06108739}); - - lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873, - -0.1483596, -0.10639995, -0.091433935, 0.058573797, - -0.06809782, -0.07889636, -0.043246906, -0.09829136, - -0.4279842, 0.034901652, 0.18797937, 0.0075234566, - 0.016178843, 0.1749513, 0.13975595, 0.92058027}); - - lstm.SetOutputGateBias( - {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795, - 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895, - 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149, - -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877}); - - lstm.SetRecurrentToInputWeights( - {-0.001374326, -0.078856036, 0.10672688, 0.029162422, - -0.11585556, 0.02557986, -0.13446963, -0.035785314, - -0.01244275, 0.025961924, -0.02337298, -0.044228926, - -0.055839065, -0.046598054, -0.010546039, 
-0.06900766, - 0.027239809, 0.022582639, -0.013296484, -0.05459212, - 0.08981, -0.045407712, 0.08682226, -0.06867011, - -0.14390695, -0.02916037, 0.000996957, 0.091420636, - 0.14283475, -0.07390571, -0.06402044, 0.062524505, - -0.093129106, 0.04860203, -0.08364217, -0.08119002, - 0.009352075, 0.22920375, 0.0016303885, 0.11583097, - -0.13732095, 0.012405723, -0.07551853, 0.06343048, - 0.12162708, -0.031923793, -0.014335606, 0.01790974, - -0.10650317, -0.0724401, 0.08554849, -0.05727212, - 0.06556731, -0.042729504, -0.043227166, 0.011683251, - -0.013082158, -0.029302018, -0.010899579, -0.062036745, - -0.022509435, -0.00964907, -0.01567329, 0.04260106, - -0.07787477, -0.11576462, 0.017356863, 0.048673786, - -0.017577527, -0.05527947, -0.082487635, -0.040137455, - -0.10820036, -0.04666372, 0.022746278, -0.07851417, - 0.01068115, 0.032956902, 0.022433773, 0.0026891115, - 0.08944216, -0.0685835, 0.010513544, 0.07228705, - 0.02032331, -0.059686817, -0.0005566496, -0.086984694, - 0.040414046, -0.1380399, 0.094208956, -0.05722982, - 0.012092817, -0.04989123, -0.086576, -0.003399834, - -0.04696032, -0.045747425, 0.10091314, 0.048676282, - -0.029037097, 0.031399418, -0.0040285117, 0.047237843, - 0.09504992, 0.041799378, -0.049185462, -0.031518843, - -0.10516937, 0.026374253, 0.10058866, -0.0033195973, - -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, - -0.10167381, 0.042500053, -0.01447153, 0.06464186, - -0.017142897, 0.03312627, 0.009205989, 0.024138335, - -0.011337001, 0.035530265, -0.010912711, 0.0706555, - -0.005894094, 0.051841937, -0.1401738, -0.02351249, - 0.0365468, 0.07590991, 0.08838724, 0.021681072, - -0.10086113, 0.019608743, -0.06195883, 0.077335775, - 0.023646897, -0.095322326, 0.02233014, 0.09756986, - -0.048691444, -0.009579111, 0.07595467, 0.11480546, - -0.09801813, 0.019894179, 0.08502348, 0.004032281, - 0.037211012, 0.068537936, -0.048005626, -0.091520436, - -0.028379958, -0.01556313, 0.06554592, -0.045599163, - -0.01672207, -0.020169014, 
-0.011877351, -0.20212261, - 0.010889619, 0.0047078193, 0.038385306, 0.08540671, - -0.017140968, -0.0035865551, 0.016678626, 0.005633034, - 0.015963363, 0.00871737, 0.060130805, 0.028611384, - 0.10109069, -0.015060172, -0.07894427, 0.06401885, - 0.011584063, -0.024466386, 0.0047652307, -0.09041358, - 0.030737216, -0.0046374933, 0.14215417, -0.11823516, - 0.019899689, 0.006106124, -0.027092824, 0.0786356, - 0.05052217, -0.058925, -0.011402121, -0.024987547, - -0.0013661642, -0.06832946, -0.015667673, -0.1083353, - -0.00096863037, -0.06988685, -0.053350925, -0.027275559, - -0.033664223, -0.07978348, -0.025200296, -0.017207067, - -0.058403496, -0.055697463, 0.005798788, 0.12965427, - -0.062582195, 0.0013350133, -0.10482091, 0.0379771, - 0.072521195, -0.0029455067, -0.13797039, -0.03628521, - 0.013806405, -0.017858358, -0.01008298, -0.07700066, - -0.017081132, 0.019358726, 0.0027079724, 0.004635139, - 0.062634714, -0.02338735, -0.039547626, -0.02050681, - 0.03385117, -0.083611414, 0.002862572, -0.09421313, - 0.058618143, -0.08598433, 0.00972939, 0.023867095, - -0.053934585, -0.023203006, 0.07452513, -0.048767887, - -0.07314807, -0.056307215, -0.10433547, -0.06440842, - 0.04328182, 0.04389765, -0.020006588, -0.09076438, - -0.11652589, -0.021705797, 0.03345259, -0.010329105, - -0.025767034, 0.013057034, -0.07316461, -0.10145612, - 0.06358255, 0.18531723, 0.07759293, 0.12006465, - 0.1305557, 0.058638252, -0.03393652, 0.09622831, - -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, - -0.005644518, 0.06857898, -0.12598175, -0.035084512, - 0.03156317, -0.12794146, -0.031963028, 0.04692781, - 0.030070418, 0.0071660685, -0.095516115, -0.004643372, - 0.040170413, -0.062104587, -0.0037324072, 0.0554317, - 0.08184801, -0.019164372, 0.06791302, 0.034257166, - -0.10307039, 0.021943003, 0.046745934, 0.0790918, - -0.0265588, -0.007824208, 0.042546265, -0.00977924, - -0.0002440307, -0.017384544, -0.017990116, 0.12252321, - -0.014512694, -0.08251313, 0.08861942, 0.13589665, - 
0.026351685, 0.012641483, 0.07466548, 0.044301085, - -0.045414884, -0.051112458, 0.03444247, -0.08502782, - -0.04106223, -0.028126027, 0.028473156, 0.10467447}); - - lstm.SetRecurrentToForgetWeights( - {-0.057784554, -0.026057621, -0.068447545, -0.022581743, - 0.14811787, 0.10826372, 0.09471067, 0.03987225, - -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, - 0.08414449, -0.022036452, -0.00066928595, -0.09203576, - 0.032950465, -0.10985798, -0.023809856, 0.0021431844, - -0.02196096, -0.00326074, 0.00058621005, -0.074678116, - -0.06193199, 0.055729095, 0.03736828, 0.020123724, - 0.061878487, -0.04729229, 0.034919553, -0.07585433, - -0.04421272, -0.044019096, 0.085488975, 0.04058006, - -0.06890133, -0.030951202, -0.024628663, -0.07672815, - 0.034293607, 0.08556707, -0.05293577, -0.033561368, - -0.04899627, 0.0241671, 0.015736353, -0.095442444, - -0.029564252, 0.016493602, -0.035026584, 0.022337519, - -0.026871363, 0.004780428, 0.0077918363, -0.03601621, - 0.016435321, -0.03263031, -0.09543275, -0.047392778, - 0.013454138, 0.028934088, 0.01685226, -0.086110644, - -0.046250615, -0.01847454, 0.047608484, 0.07339695, - 0.034546845, -0.04881143, 0.009128804, -0.08802852, - 0.03761666, 0.008096139, -0.014454086, 0.014361001, - -0.023502491, -0.0011840804, -0.07607001, 0.001856849, - -0.06509276, -0.006021153, -0.08570962, -0.1451793, - 0.060212336, 0.055259194, 0.06974018, 0.049454916, - -0.027794661, -0.08077226, -0.016179763, 0.1169753, - 0.17213494, -0.0056326236, -0.053934924, -0.0124349, - -0.11520337, 0.05409887, 0.088759385, 0.0019655675, - 0.0042065294, 0.03881498, 0.019844765, 0.041858196, - -0.05695512, 0.047233116, 0.038937137, -0.06542224, - 0.014429736, -0.09719407, 0.13908425, -0.05379757, - 0.012321099, 0.082840554, -0.029899208, 0.044217527, - 0.059855383, 0.07711018, -0.045319796, 0.0948846, - -0.011724666, -0.0033288454, -0.033542685, -0.04764985, - -0.13873616, 0.040668588, 0.034832682, -0.015319203, - -0.018715994, 0.046002675, 0.0599172, 
-0.043107376, - 0.0294216, -0.002314414, -0.022424703, 0.0030315618, - 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, - 0.12375372, -0.0006038222, 0.029104086, 0.087442465, - 0.052958444, 0.07558703, 0.04817258, 0.044462286, - -0.015213451, -0.08783778, -0.0561384, -0.003008196, - 0.047060397, -0.002058388, 0.03429439, -0.018839769, - 0.024734668, 0.024614193, -0.042046934, 0.09597743, - -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, - -0.02558259, -0.022822596, -0.023273505, -0.02464396, - -0.10991725, -0.006240552, 0.0074488563, 0.024044557, - 0.04383914, -0.046476185, 0.028658995, 0.060410924, - 0.050786525, 0.009452605, -0.0073054377, -0.024810238, - 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, - 0.015898481, 0.021362653, -0.030262267, 0.016587038, - -0.011442813, 0.041154444, -0.007631438, -0.03423484, - -0.010977775, 0.036152758, 0.0066366293, 0.11915515, - 0.02318443, -0.041350313, 0.021485701, -0.10906167, - -0.028218046, -0.00954771, 0.020531068, -0.11995105, - -0.03672871, 0.024019798, 0.014255957, -0.05221243, - -0.00661567, -0.04630967, 0.033188973, 0.10107534, - -0.014027541, 0.030796422, -0.10270911, -0.035999842, - 0.15443139, 0.07684145, 0.036571592, -0.035900835, - -0.0034699554, 0.06209149, 0.015920248, -0.031122351, - -0.03858649, 0.01849943, 0.13872518, 0.01503974, - 0.069941424, -0.06948533, -0.0088794185, 0.061282158, - -0.047401894, 0.03100163, -0.041533746, -0.10430945, - 0.044574402, -0.01425562, -0.024290353, 0.034563623, - 0.05866852, 0.023947537, -0.09445152, 0.035450947, - 0.02247216, -0.0042998926, 0.061146557, -0.10250651, - 0.020881841, -0.06747029, 0.10062043, -0.0023941975, - 0.03532124, -0.016341697, 0.09685456, -0.016764693, - 0.051808182, 0.05875331, -0.04536488, 0.001626336, - -0.028892258, -0.01048663, -0.009793449, -0.017093895, - 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, - -0.001845119, -0.03551521, 0.0018358806, 0.05763657, - -0.01769146, 0.040995963, 0.02235177, -0.060430344, - 
0.11475477, -0.023854522, 0.10071741, 0.0686208, - -0.014250481, 0.034261297, 0.047418304, 0.08562733, - -0.030519066, 0.0060542435, 0.014653856, -0.038836084, - 0.04096551, 0.032249358, -0.08355519, -0.026823482, - 0.056386515, -0.010401743, -0.028396193, 0.08507674, - 0.014410365, 0.020995233, 0.17040324, 0.11511526, - 0.02459721, 0.0066619175, 0.025853224, -0.023133837, - -0.081302024, 0.017264642, -0.009585969, 0.09491168, - -0.051313367, 0.054532815, -0.014298593, 0.10657464, - 0.007076659, 0.10964551, 0.0409152, 0.008275321, - -0.07283536, 0.07937492, 0.04192024, -0.1075027}); - - lstm.SetRecurrentToCellWeights( - {-0.037322544, 0.018592842, 0.0056175636, -0.06253426, - 0.055647098, -0.05713207, -0.05626563, 0.005559383, - 0.03375411, -0.025757805, -0.088049285, 0.06017052, - -0.06570978, 0.007384076, 0.035123326, -0.07920549, - 0.053676967, 0.044480428, -0.07663568, 0.0071805613, - 0.08089997, 0.05143358, 0.038261272, 0.03339287, - -0.027673481, 0.044746667, 0.028349208, 0.020090483, - -0.019443132, -0.030755889, -0.0040000007, 0.04465846, - -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, - -0.10893326, 0.076739706, -0.08509834, -0.027997585, - 0.037871376, 0.01449768, -0.09002357, -0.06111149, - -0.046195522, 0.0422062, -0.005683705, -0.1253618, - -0.012925729, -0.04890792, 0.06985068, 0.037654128, - 0.03398274, -0.004781977, 0.007032333, -0.031787455, - 0.010868644, -0.031489216, 0.09525667, 0.013939797, - 0.0058680447, 0.0167067, 0.02668468, -0.04797466, - -0.048885044, -0.12722108, 0.035304096, 0.06554885, - 0.00972396, -0.039238118, -0.05159735, -0.11329045, - 0.1613692, -0.03750952, 0.06529313, -0.071974665, - -0.11769596, 0.015524369, -0.0013754242, -0.12446318, - 0.02786344, -0.014179351, 0.005264273, 0.14376344, - 0.015983658, 0.03406988, -0.06939408, 0.040699873, - 0.02111075, 0.09669095, 0.041345075, -0.08316494, - -0.07684199, -0.045768797, 0.032298047, -0.041805092, - 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, - -0.024950314, 
0.11574242, 0.04508852, -0.04335324, - 0.06760663, -0.027437469, 0.07216407, 0.06977076, - -0.05438599, 0.034033038, -0.028602652, 0.05346137, - 0.043184172, -0.037189785, 0.10420091, 0.00882477, - -0.054019816, -0.074273005, -0.030617684, -0.0028467078, - 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, - 0.04361412, -0.007001822, 0.09631092, -0.06702025, - -0.042049985, -0.035070654, -0.04103342, -0.10273396, - 0.0544271, 0.037184782, -0.13150354, -0.0058036847, - -0.008264958, 0.042035464, 0.05891794, 0.029673764, - 0.0063542654, 0.044788733, 0.054816857, 0.062257513, - -0.00093483756, 0.048938446, -0.004952862, -0.007730018, - -0.04043371, -0.017094059, 0.07229206, -0.023670016, - -0.052195564, -0.025616996, -0.01520939, 0.045104615, - -0.007376126, 0.003533447, 0.006570588, 0.056037236, - 0.12436656, 0.051817212, 0.028532185, -0.08686856, - 0.11868599, 0.07663395, -0.07323171, 0.03463402, - -0.050708205, -0.04458982, -0.11590894, 0.021273347, - 0.1251325, -0.15313013, -0.12224372, 0.17228661, - 0.023029093, 0.086124025, 0.006445803, -0.03496501, - 0.028332196, 0.04449512, -0.042436164, -0.026587414, - -0.006041347, -0.09292539, -0.05678812, 0.03897832, - 0.09465633, 0.008115513, -0.02171956, 0.08304309, - 0.071401566, 0.019622514, 0.032163795, -0.004167056, - 0.02295182, 0.030739572, 0.056506045, 0.004612461, - 0.06524936, 0.059999723, 0.046395954, -0.0045512207, - -0.1335546, -0.030136576, 0.11584653, -0.014678886, - 0.0020118146, -0.09688814, -0.0790206, 0.039770417, - -0.0329582, 0.07922767, 0.029322514, 0.026405897, - 0.04207835, -0.07073373, 0.063781224, 0.0859677, - -0.10925287, -0.07011058, 0.048005477, 0.03438226, - -0.09606514, -0.006669445, -0.043381985, 0.04240257, - -0.06955775, -0.06769346, 0.043903265, -0.026784198, - -0.017840602, 0.024307009, -0.040079936, -0.019946516, - 0.045318738, -0.12233574, 0.026170589, 0.0074471775, - 0.15978073, 0.10185836, 0.10298046, -0.015476589, - -0.039390966, -0.072174534, 0.0739445, -0.1211869, - 
-0.0347889, -0.07943156, 0.014809798, -0.12412325, - -0.0030663363, 0.039695457, 0.0647603, -0.08291318, - -0.018529687, -0.004423833, 0.0037507233, 0.084633216, - -0.01514876, -0.056505352, -0.012800942, -0.06994386, - 0.012962922, -0.031234352, 0.07029052, 0.016418684, - 0.03618972, 0.055686004, -0.08663945, -0.017404709, - -0.054761406, 0.029065743, 0.052404847, 0.020238016, - 0.0048197987, -0.0214882, 0.07078733, 0.013016777, - 0.06262858, 0.009184685, 0.020785125, -0.043904778, - -0.0270329, -0.03299152, -0.060088247, -0.015162964, - -0.001828936, 0.12642565, -0.056757294, 0.013586685, - 0.09232601, -0.035886683, 0.06000002, 0.05229691, - -0.052580316, -0.082029596, -0.010794592, 0.012947712, - -0.036429964, -0.085508935, -0.13127148, -0.017744139, - 0.031502828, 0.036232427, -0.031581745, 0.023051167, - -0.05325106, -0.03421577, 0.028793324, -0.034633752, - -0.009881397, -0.043551125, -0.018609839, 0.0019097115, - -0.008799762, 0.056595087, 0.0022273948, 0.055752404}); - - lstm.SetRecurrentToOutputWeights({ - 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415, - -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349, - -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948, - -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774, - -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125, - -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224, - -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088, - 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867, - -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728, - 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607, - -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928, - -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462, - 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879, - 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698, - 
-0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146, - 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345, - 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166, - 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203, - 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743, - 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415, - -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618, - 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891, - -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015, - 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109, - 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886, - 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396, - -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282, - -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025, - -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575, - -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277, - -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719, - -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215, - 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483, - 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102, - -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775, - 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841, - -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656, - -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286, - -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309, - 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545, - 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754, - 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831, - -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697, - 0.10285965, 0.084736146, 0.15568255, 
-0.00040734606, 0.027835453, - -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222, - -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989, - -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827, - -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949, - 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819, - -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954, - 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228, - -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001, - -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939, - -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556, - -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718, - 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893, - 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974, - -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485, - 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856, - 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853, - -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019, - 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024, - 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994, - 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621, - }); - - lstm.SetCellToInputWeights( - {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, - -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, - -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, - 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}); - - lstm.SetCellToForgetWeights( - {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, - -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, - -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, - 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}); - - 
lstm.SetCellToOutputWeights( - {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, - -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, - -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, - 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}); - - lstm.SetProjectionWeights( - {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, - 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683, - -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931, - -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476, - 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067, - 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787, - 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, - 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285, - -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949, - -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768, - -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929, - 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504, - 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, - 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117, - 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253, - 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456, - -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552, - 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797, - -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, - 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165, - -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922, - -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548, - 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786, - -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722, - 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, - -0.042050496, 0.16842307, 
-0.060597885, 0.10531834, -0.06411776, - -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307, - 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969, - -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593, - -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515, - -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, - 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723, - 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097, - -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209, - 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268, - 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139, - 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, - 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871, - 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553, - -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702, - -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615, - 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187, - -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388, - -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709, - 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263, - 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777, - 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935, - -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641, - -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996, - -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318, - 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437, - -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079, - 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237, - 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415, - -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, - -0.19379048, -0.218606, 0.21448623, 
0.017840758, 0.1416943, - -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311, - 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013, - -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364, - -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543, - -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, - 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906, - 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955, - 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656}); - - static float lstm_input[][20] = { - {// Batch0: 4 (input_sequence_size) * 5 (n_input) - 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386, - 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199, - 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, - - {// Batch1: 4 (input_sequence_size) * 5 (n_input) - 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260, - 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485, - 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}; - - static float lstm_golden_output[][64] = { - {// Batch0: 4 (input_sequence_size) * 16 (n_output) - -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, - -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, - -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, - 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, - -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, - -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, - 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, - 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, - 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, - 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, - -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, - -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, - 0.0286833, 0.00824207, 0.0264887, 0.0305169}, - {// 
Batch1: 4 (input_sequence_size) * 16 (n_output) - -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, - -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, - 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, - 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, - -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, - -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, - 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, - 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, - 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, - 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, - -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, - -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, - 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToInputWeights(cell_to_input_weights_); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + lstm.SetProjectionWeights(projection_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - const int input_sequence_size = - sizeof(lstm_input[0]) / sizeof(float) / (lstm.num_inputs()); - for (int i = 0; i < input_sequence_size; i++) { - float* batch0_start = lstm_input[0] + i * lstm.num_inputs(); - float* 
batch0_end = batch0_start + lstm.num_inputs(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetInput(0, batch0_start, batch0_end); +TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; - float* batch1_start = lstm_input[1] + i * lstm.num_inputs(); - float* batch1_end = batch1_start + lstm.num_inputs(); - lstm.SetInput(lstm.num_inputs(), batch1_start, batch1_end); + HybridLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + 
lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToInputWeights(cell_to_input_weights_); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + lstm.SetProjectionWeights(projection_weights_); - lstm.Invoke(); + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); - float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs(); - float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs(); - float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs(); - float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs(); - std::vector<float> expected; - expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); - expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); - } + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/maximum_minimum.cc b/tensorflow/contrib/lite/kernels/maximum_minimum.cc index 5a28d663c9e756040746f0a98b356afba76cceab..8d676218bdcf71a7acadf62f213d35c6997f7575 100644 --- a/tensorflow/contrib/lite/kernels/maximum_minimum.cc +++ b/tensorflow/contrib/lite/kernels/maximum_minimum.cc @@ -41,8 +41,8 @@ struct OpContext { input2 = GetInput(context, node, kInputTensor2); output = GetOutput(context, node, kOutputTensor); } - TfLiteTensor* input1; - TfLiteTensor* input2; + const TfLiteTensor* input1; + const TfLiteTensor* input2; TfLiteTensor* output; }; diff --git a/tensorflow/contrib/lite/kernels/mean.cc b/tensorflow/contrib/lite/kernels/mean.cc index
98f80e32d95b47cbe6267457f9508e2897e804d9..03e5db24de3f3c2d4e17df21bc0b592a02078d6b 100644 --- a/tensorflow/contrib/lite/kernels/mean.cc +++ b/tensorflow/contrib/lite/kernels/mean.cc @@ -40,8 +40,8 @@ struct MeanContext { output = GetOutput(context, node, 0); } TfLiteMeanParams* params; - TfLiteTensor* input; - TfLiteTensor* axis; + const TfLiteTensor* input; + const TfLiteTensor* axis; TfLiteTensor* output; }; diff --git a/tensorflow/contrib/lite/kernels/mfcc.cc b/tensorflow/contrib/lite/kernels/mfcc.cc index 018db0dc54c5d281bf3fb3ff8a1f111b427fe76b..3f5bc4d68a57daa8423953f591ac139dc55eacb9 100644 --- a/tensorflow/contrib/lite/kernels/mfcc.cc +++ b/tensorflow/contrib/lite/kernels/mfcc.cc @@ -67,8 +67,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* inputWav = GetInput(context, node, kInputTensorWav); - TfLiteTensor* inputRate = GetInput(context, node, kInputTensorRate); + const TfLiteTensor* inputWav = GetInput(context, node, kInputTensorWav); + const TfLiteTensor* inputRate = GetInput(context, node, kInputTensorRate); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TF_LITE_ENSURE_EQ(context, NumDimensions(inputWav), 3); @@ -94,8 +94,8 @@ template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->user_data); - TfLiteTensor* inputWav = GetInput(context, node, kInputTensorWav); - TfLiteTensor* inputRate = GetInput(context, node, kInputTensorRate); + const TfLiteTensor* inputWav = GetInput(context, node, kInputTensorWav); + const TfLiteTensor* inputRate = GetInput(context, node, kInputTensorRate); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); const int32 sample_rate = *GetTensorData(inputRate); diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc index 
54575019de4c678ce25561cf2ac8dc80c9973363..62f4e94a386fbbc6987e8a6dc1a9a47ce3349cbb 100644 --- a/tensorflow/contrib/lite/kernels/mul.cc +++ b/tensorflow/contrib/lite/kernels/mul.cc @@ -57,8 +57,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); - TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TF_LITE_ENSURE_EQ(context, input1->type, input2->type); @@ -80,7 +80,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { template <KernelType kernel_type> void EvalFloat(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params, const OpData* data, - TfLiteTensor* input1, TfLiteTensor* input2, + const TfLiteTensor* input1, const TfLiteTensor* input2, TfLiteTensor* output) { float output_activation_min, output_activation_max; CalculateActivationRangeFloat(params->activation, &output_activation_min, @@ -109,7 +109,7 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, template <KernelType kernel_type> void EvalQuantized(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params, const OpData* data, - TfLiteTensor* input1, TfLiteTensor* input2, + const TfLiteTensor* input1, const TfLiteTensor* input2, TfLiteTensor* output) { auto input1_offset = -input1->params.zero_point; auto input2_offset = -input2->params.zero_point; @@ -149,8 +149,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data); OpData* data = reinterpret_cast<OpData*>(node->user_data); - TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); - TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + const TfLiteTensor* input1 = GetInput(context, node,
kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); if (output->type == kTfLiteFloat32) { @@ -159,8 +159,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { EvalQuantized<kernel_type>(context, node, params, data, input1, input2, output); } else { - context->ReportError(context, - "Mul only supports FLOAT32 and quantized UINT8 now."); + context->ReportError( + context, "Mul only supports FLOAT32 and quantized UINT8 now, got %d.", + output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/neg.cc b/tensorflow/contrib/lite/kernels/neg.cc index 692da817272958fdb4a789d04bf43ed4f79731b4..4124c05388cca180c2b417603e6d239f1f97b5bf 100644 --- a/tensorflow/contrib/lite/kernels/neg.cc +++ b/tensorflow/contrib/lite/kernels/neg.cc @@ -27,7 +27,7 @@ constexpr int kOutputTensor = 0; TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); output->type = input->type; @@ -44,7 +44,7 @@ void Negate(const T* in_data, int num_elements, T* out_data) { } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); const int num_elements = NumElements(input); switch (input->type) { @@ -59,7 +59,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { break; default: context->ReportError( - context, "Neg only currently supports int64, int32, and float32.", + context, + "Neg only currently supports int64, int32, and float32, got %d.", input->type); return kTfLiteError;
} diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc index 9e1e4658e971eac89da7d96219223687b06ab22a..83668cb4ca87e9eb53ab4ba9e88f91e3315594de 100644 --- a/tensorflow/contrib/lite/kernels/pad.cc +++ b/tensorflow/contrib/lite/kernels/pad.cc @@ -45,9 +45,9 @@ struct PadContext { output = GetOutput(context, node, 0); dims = NumDimensions(input); } - TfLiteTensor* constant_values; - TfLiteTensor* input; - TfLiteTensor* paddings; + const TfLiteTensor* constant_values; + const TfLiteTensor* input; + const TfLiteTensor* paddings; TfLiteTensor* output; int dims; }; @@ -199,7 +199,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } } break; default: - context->ReportError(context, "Type is currently not supported by Pad."); + context->ReportError(context, + "Type %d is currently not supported by Pad.", + op_context.input->type); return kTfLiteError; } #undef TF_LITE_PAD diff --git a/tensorflow/contrib/lite/kernels/padding.h b/tensorflow/contrib/lite/kernels/padding.h index e81b970e0fb149e8c5d95ed12622917fdc336f7a..3cb55f19a99f3e54c2ef14b8b890b286ad25f3d1 100644 --- a/tensorflow/contrib/lite/kernels/padding.h +++ b/tensorflow/contrib/lite/kernels/padding.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ +#include "tensorflow/contrib/lite/builtin_op_data.h" + namespace tflite { inline int ComputePadding(int stride, int dilation_rate, int in_size, @@ -24,6 +26,33 @@ inline int ComputePadding(int stride, int dilation_rate, int in_size, return padding > 0 ? padding : 0; } +// Matching GetWindowedOutputSize in TensorFlow. 
+inline int ComputeOutSize(TfLitePadding padding, int image_size, + int filter_size, int stride) { + switch (padding) { + case kTfLitePaddingSame: + return (image_size + stride - 1) / stride; + case kTfLitePaddingValid: + return (image_size + stride - filter_size) / stride; + default: + return 0; + } +} + +inline TfLitePaddingValues ComputePaddingHeightWidth( + int stride_height, int stride_width, int dilation_rate, int in_height, + int in_width, int filter_height, int filter_width, TfLitePadding padding) { + int out_width = ComputeOutSize(padding, in_width, filter_width, stride_width); + int out_height = + ComputeOutSize(padding, in_height, filter_height, stride_height); + + TfLitePaddingValues padding_values; + padding_values.height = + ComputePadding(stride_height, 1, in_height, filter_height, out_height); + padding_values.width = + ComputePadding(stride_width, 1, in_width, filter_width, out_width); + return padding_values; +} } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc index 0bf27c34c1337b4ae4b8b73ee2dafcc931c7ce3c..311e9b8399726d758182e1f084a890d6f10e57ce 100644 --- a/tensorflow/contrib/lite/kernels/pooling.cc +++ b/tensorflow/contrib/lite/kernels/pooling.cc @@ -69,7 +69,7 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); TfLiteTensor* output = GetOutput(context, node, 0); - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); TF_LITE_ENSURE_EQ(context, input->type, output->type); @@ -122,7 +122,7 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { template void AverageEvalFloat(TfLiteContext* context, TfLiteNode* node, TfLitePoolParams* params, OpData* data, - TfLiteTensor* input, 
TfLiteTensor* output) { + const TfLiteTensor* input, TfLiteTensor* output) { float activation_min, activation_max; CalculateActivationRangeFloat(params->activation, &activation_min, &activation_max); @@ -143,7 +143,7 @@ void AverageEvalFloat(TfLiteContext* context, TfLiteNode* node, template <KernelType kernel_type> void AverageEvalQuantized(TfLiteContext* context, TfLiteNode* node, TfLitePoolParams* params, OpData* data, - TfLiteTensor* input, TfLiteTensor* output) { + const TfLiteTensor* input, TfLiteTensor* output) { int32_t activation_min; int32_t activation_max; CalculateActivationRangeUint8(params->activation, output, &activation_min, @@ -165,8 +165,8 @@ void AverageEvalQuantized(TfLiteContext* context, TfLiteNode* node, template <KernelType kernel_type> void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node, - TfLitePoolParams* params, OpData* data, TfLiteTensor* input, - TfLiteTensor* output) { + TfLitePoolParams* params, OpData* data, + const TfLiteTensor* input, TfLiteTensor* output) { float activation_min, activation_max; CalculateActivationRangeFloat(params->activation, &activation_min, &activation_max); @@ -187,7 +187,7 @@ void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node, template <KernelType kernel_type> void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node, TfLitePoolParams* params, OpData* data, - TfLiteTensor* input, TfLiteTensor* output) { + const TfLiteTensor* input, TfLiteTensor* output) { int32_t activation_min; int32_t activation_max; CalculateActivationRangeUint8(params->activation, output, &activation_min, @@ -209,8 +209,8 @@ void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node, template <KernelType kernel_type> void L2EvalFloat(TfLiteContext* context, TfLiteNode* node, - TfLitePoolParams* params, OpData* data, TfLiteTensor* input, - TfLiteTensor* output) { + TfLitePoolParams* params, OpData* data, + const TfLiteTensor* input, TfLiteTensor* output) { float activation_min, activation_max; CalculateActivationRangeFloat(params->activation, &activation_min, &activation_max); @@ -236,7 +236,7 @@ TfLiteStatus
AverageEval(TfLiteContext* context, TfLiteNode* node) { OpData* data = reinterpret_cast<OpData*>(node->user_data); TfLiteTensor* output = GetOutput(context, node, 0); - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); switch (input->type) { // Already know in/out types are same. case kTfLiteFloat32: AverageEvalFloat<kernel_type>(context, node, params, data, input, output); @@ -246,7 +246,8 @@ TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) { output); break; default: - context->ReportError(context, "Type not currently supported."); + context->ReportError(context, "Type %d not currently supported.", + input->type); return kTfLiteError; } return kTfLiteOk; @@ -258,7 +259,7 @@ TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) { OpData* data = reinterpret_cast<OpData*>(node->user_data); TfLiteTensor* output = GetOutput(context, node, 0); - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); switch (input->type) { // Already know in/out types are same. case kTfLiteFloat32: MaxEvalFloat<kernel_type>(context, node, params, data, input, output); @@ -267,7 +268,8 @@ TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) { MaxEvalQuantized<kernel_type>(context, node, params, data, input, output); break; default: - context->ReportError(context, "Type not currently supported."); + context->ReportError(context, "Type %d not currently supported.", + input->type); return kTfLiteError; } return kTfLiteOk; @@ -279,7 +281,7 @@ TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) { OpData* data = reinterpret_cast<OpData*>(node->user_data); TfLiteTensor* output = GetOutput(context, node, 0); - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32: L2EvalFloat(context, node, params, data, input, output); @@ -288,7 +290,8 @@ TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) { // We don't have a quantized implementation, so just fall through to the // 'default' case. default: - context->ReportError(context, "Type not currently supported."); + context->ReportError(context, "Type %d not currently supported.", + input->type); return kTfLiteError; } return kTfLiteOk; diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 5df35aac62141f9516e6ce5b31220951f2b0accb..7bb28d4de7402a45954691a2e031e3b6b7433ffb 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -73,6 +73,7 @@ TfLiteRegistration* Register_SQUEEZE(); TfLiteRegistration* Register_STRIDED_SLICE(); TfLiteRegistration* Register_EXP(); TfLiteRegistration* Register_TOPK_V2(); +TfLiteRegistration* Register_LOG(); TfLiteRegistration* Register_LOG_SOFTMAX(); TfLiteRegistration* Register_CAST(); TfLiteRegistration* Register_DEQUANTIZE(); @@ -85,8 +86,16 @@ TfLiteRegistration* Register_GREATER_EQUAL(); TfLiteRegistration* Register_LESS(); TfLiteRegistration* Register_LESS_EQUAL(); TfLiteRegistration* Register_FLOOR(); +TfLiteRegistration* Register_TILE(); TfLiteRegistration* Register_NEG(); TfLiteRegistration* Register_SELECT(); +TfLiteRegistration* Register_SLICE(); +TfLiteRegistration* Register_SIN(); +TfLiteRegistration* Register_TRANSPOSE_CONV(); +TfLiteRegistration* Register_EXPAND_DIMS(); +TfLiteRegistration* Register_SPARSE_TO_DENSE(); +TfLiteRegistration* Register_EQUAL(); +TfLiteRegistration* Register_NOT_EQUAL(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -120,7 +129,8 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION()); AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, 
Register_LOCAL_RESPONSE_NORMALIZATION()); - AddBuiltin(BuiltinOperator_LSTM, Register_LSTM()); + AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, Register_BIDIRECTIONAL_SEQUENCE_LSTM()); AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, @@ -141,6 +151,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE()); AddBuiltin(BuiltinOperator_EXP, Register_EXP()); AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2()); + AddBuiltin(BuiltinOperator_LOG, Register_LOG()); AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX()); AddBuiltin(BuiltinOperator_CAST, Register_CAST()); AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE()); @@ -155,6 +166,14 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR()); AddBuiltin(BuiltinOperator_NEG, Register_NEG()); AddBuiltin(BuiltinOperator_SELECT, Register_SELECT()); + AddBuiltin(BuiltinOperator_SLICE, Register_SLICE()); + AddBuiltin(BuiltinOperator_SIN, Register_SIN()); + AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV()); + AddBuiltin(BuiltinOperator_TILE, Register_TILE()); + AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS()); + AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE()); + AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL()); + AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. @@ -163,29 +182,6 @@ BuiltinOpResolver::BuiltinOpResolver() { tflite::ops::custom::Register_AUDIO_SPECTROGRAM()); } -TfLiteRegistration* BuiltinOpResolver::FindOp( - tflite::BuiltinOperator op) const { - auto it = builtins_.find(op); - return it != builtins_.end() ? 
it->second : nullptr; -} - -TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op) const { - auto it = custom_ops_.find(op); - return it != custom_ops_.end() ? it->second : nullptr; -} - -void BuiltinOpResolver::AddBuiltin(tflite::BuiltinOperator op, - TfLiteRegistration* registration) { - registration->builtin_code = op; - builtins_.insert(std::make_pair(op, registration)); -} - -void BuiltinOpResolver::AddCustom(const char* name, - TfLiteRegistration* registration) { - registration->builtin_code = BuiltinOperator_CUSTOM; - custom_ops_.insert(std::make_pair(std::string(name), registration)); -} - } // namespace builtin } // namespace ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h index b9cff0ae21086b44e0c920095d5f6c9668346f38..b928f1b302580d52f708bbf85dfcfc0f79ff1e69 100644 --- a/tensorflow/contrib/lite/kernels/register.h +++ b/tensorflow/contrib/lite/kernels/register.h @@ -23,24 +23,9 @@ namespace tflite { namespace ops { namespace builtin { -class BuiltinOpResolver : public OpResolver { +class BuiltinOpResolver : public MutableOpResolver { public: BuiltinOpResolver(); - TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override; - TfLiteRegistration* FindOp(const char* op) const override; - void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration); - void AddCustom(const char* name, TfLiteRegistration* registration); - - private: - struct BuiltinOperatorHasher { - size_t operator()(const tflite::BuiltinOperator& x) const { - return std::hash()(static_cast(x)); - } - }; - std::unordered_map - builtins_; - std::unordered_map custom_ops_; }; } // namespace builtin diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc index 438f70d3115130efe477a3ceeccd2e77108c979a..3287040695140e3e7921c9f517450b9416b050b6 100644 --- a/tensorflow/contrib/lite/kernels/reshape.cc +++ 
b/tensorflow/contrib/lite/kernels/reshape.cc @@ -35,7 +35,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // Tensorflow's Reshape allows one of the shape components to have the @@ -70,7 +70,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); memcpy(output->data.raw, input->data.raw, input->bytes); diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc index 9e3e19c09a4012ebdadbc2a7c2ba06c4bfefd206..86c4cd3ee88013ca4174f444d0388bc036d9cde6 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc @@ -36,8 +36,10 @@ constexpr int kInputTensor = 0; constexpr int kSizeTensor = 1; constexpr int kOutputTensor = 0; -TfLiteStatus ResizeOutputTensor(TfLiteContext* context, TfLiteTensor* input, - TfLiteTensor* size, TfLiteTensor* output) { +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + const TfLiteTensor* input, + const TfLiteTensor* size, + TfLiteTensor* output) { TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); output_size->data[0] = input->dims->data[0]; const int32* size_data = GetTensorData(size); @@ -51,20 +53,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input = GetInput(context, node, 
kInputTensor); - TfLiteTensor* size = GetInput(context, node, kSizeTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* size = GetInput(context, node, kSizeTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // TODO(ahentz): Our current implementations rely on the inputs being 4D. TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1); - // TODO(ahentz): Our current implementations only support float32. - TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32); // ResizeBilinear creates a float tensor even when the input is made of // integers. - output->type = kTfLiteFloat32; + output->type = input->type; if (!IsConstantTensor(size)) { SetTensorToDynamic(output); @@ -78,9 +78,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TfLiteTensor* size = GetInput(context, node, kSizeTensor); + const TfLiteTensor* size = GetInput(context, node, kSizeTensor); if (IsDynamicTensor(output)) { TF_LITE_ENSURE_OK(context, @@ -88,21 +88,29 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } if (output->type == kTfLiteFloat32) { -#define TF_LITE_RESIZE_BILINEAR(type) \ - type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ - GetTensorData(size), GetTensorDims(size), \ - GetTensorData(output), GetTensorDims(output), \ +#define TF_LITE_RESIZE_BILINEAR(type, datatype) \ + type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ + GetTensorData(size), GetTensorDims(size), \ + GetTensorData(output), GetTensorDims(output), \ params->align_corners) if (kernel_type == kReference) { - 
TF_LITE_RESIZE_BILINEAR(reference_ops); + TF_LITE_RESIZE_BILINEAR(reference_ops, float); } if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) { - TF_LITE_RESIZE_BILINEAR(optimized_ops); + TF_LITE_RESIZE_BILINEAR(optimized_ops, float); + } + } else if (output->type == kTfLiteUInt8) { + if (kernel_type == kReference) { + TF_LITE_RESIZE_BILINEAR(reference_ops, uint8_t); + } + if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) { + TF_LITE_RESIZE_BILINEAR(optimized_ops, uint8_t); } #undef TF_LITE_RESIZE_BILINEAR } else { - context->ReportError(context, "Inputs and outputs not all float types."); + context->ReportError(context, "Output type is %d, requires float.", + output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc index 4e03f3820a5c14ee1692c553db61e385716b1723..10caffea03ebcec7862df1627541ac3d076b04e4 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc @@ -22,6 +22,7 @@ namespace tflite { namespace { using ::testing::ElementsAreArray; +using uint8 = std::uint8_t; class ResizeBilinearOpModel : public SingleOpModel { public: @@ -34,7 +35,7 @@ class ResizeBilinearOpModel : public SingleOpModel { } else { size_ = AddInput({TensorType_INT32, {2}}); } - output_ = AddOutput(TensorType_FLOAT32); // Always float. 
+ output_ = AddOutput(input.type); SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR, BuiltinOptions_ResizeBilinearOptions, CreateResizeBilinearOptions(builder_).Union()); @@ -45,12 +46,16 @@ class ResizeBilinearOpModel : public SingleOpModel { } } - void SetInput(std::initializer_list data) { + template + void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } void SetSize(std::initializer_list data) { PopulateTensor(size_, data); } - std::vector GetOutput() { return ExtractVector(output_); } + template + std::vector GetOutput() { + return ExtractVector(output_); + } private: int input_; @@ -60,60 +65,121 @@ class ResizeBilinearOpModel : public SingleOpModel { TEST(ResizeBilinearOpTest, HorizontalResize) { ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}); - m.SetInput({3, 6}); + m.SetInput({3, 6}); m.SetSize({1, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3}); - const_m.SetInput({3, 6}); + const_m.SetInput({3, 6}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); +} + +TEST(ResizeBilinearOpTest, HorizontalResize8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 1, 2, 1}}); + m.SetInput({3, 6}); + m.SetSize({1, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 1, 2, 1}}, {1, 3}); + const_m.SetInput({3, 6}); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); } TEST(ResizeBilinearOpTest, VerticalResize) { ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}); - m.SetInput({3, 9}); + m.SetInput({3, 9}); m.SetSize({3, 1}); m.Invoke(); - 
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1}); - const_m.SetInput({3, 9}); + const_m.SetInput({3, 9}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); +} + +TEST(ResizeBilinearOpTest, VerticalResize8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 1, 1}}); + m.SetInput({3, 9}); + m.SetSize({3, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 1, 1}}, {3, 1}); + const_m.SetInput({3, 9}); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); } TEST(ResizeBilinearOpTest, TwoDimensionalResize) { ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}); - m.SetInput({ + m.SetInput({ 3, 6, // 9, 12 // }); m.SetSize({3, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 5, 6, // - 7, 9, 10, // - 9, 11, 12, // - }))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3}); - const_m.SetInput({ + const_m.SetInput({ 3, 6, // 9, 12 // }); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 5, 6, // - 7, 9, 10, // - 9, 11, 12, // - }))); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); +} + +TEST(ResizeBilinearOpTest, TwoDimensionalResize8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 1}}); + m.SetInput({ + 3, 6, // + 9, 12 // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), 
ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); } TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}); - m.SetInput({ + m.SetInput({ 3, 6, // 9, 12, // 4, 10, // @@ -121,60 +187,123 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { }); m.SetSize({3, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 5, 6, // - 7, 9, 10, // - 9, 11, 12, // - 4, 8, 10, // - 8, 12, 14, // - 10, 14, 16, // - }))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 14, 16, // + }))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3}); - const_m.SetInput({ + const_m.SetInput({ 3, 6, // 9, 12, // 4, 10, // 10, 16 // }); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 5, 6, // - 7, 9, 10, // - 9, 11, 12, // - 4, 8, 10, // - 8, 12, 14, // - 10, 14, 16, // - }))); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 14, 16, // + }))); } TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}); - m.SetInput({ + m.SetInput({ 3, 4, 6, 10, // 9, 10, 12, 16, // }); m.SetSize({3, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 4, 5, 8, 6, 10, // - 7, 8, 9, 12, 10, 14, // - 9, 10, 11, 14, 12, 16, // - }))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 
11, 14, 12, 16, // + }))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3}); - const_m.SetInput({ + const_m.SetInput({ 3, 4, 6, 10, // 9, 10, 12, 16, // }); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 4, 5, 8, 6, 10, // - 7, 8, 9, 12, 10, 14, // - 9, 10, 11, 14, 12, 16, // - }))); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 14, 12, 16, // + }))); +} + +TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {2, 2, 2, 1}}); + m.SetInput({ + 3, 6, // + 9, 12, // + 4, 10, // + 10, 16 // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 13, 16, // + }))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {2, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12, // + 4, 10, // + 10, 16 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 13, 16, // + }))); } +TEST(ResizeBilinearOpTest, ThreeDimensionalResize8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 2}}); + m.SetInput({ + 3, 4, 6, 10, // + 9, 10, 12, 16, // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 13, 12, 16, // + }))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3}); + const_m.SetInput({ + 3, 4, 6, 10, // + 9, 10, 12, 16, // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 13, 12, 16, // + }))); +} } // namespace } // namespace tflite diff --git 
a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc index 029ad9a709c514985e9944e646f70094693200b9..9b6cee3cb55bf93b987fa8e59bdf9c591f5c0372 100644 --- a/tensorflow/contrib/lite/kernels/select.cc +++ b/tensorflow/contrib/lite/kernels/select.cc @@ -33,10 +33,10 @@ TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input_condition = + const TfLiteTensor* input_condition = GetInput(context, node, kInputTensorCondition); - TfLiteTensor* input_x = GetInput(context, node, kInputTensorX); - TfLiteTensor* input_y = GetInput(context, node, kInputTensorY); + const TfLiteTensor* input_x = GetInput(context, node, kInputTensorX); + const TfLiteTensor* input_y = GetInput(context, node, kInputTensorY); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // Input must be bool. @@ -62,10 +62,10 @@ TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) { } TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* input_condition = + const TfLiteTensor* input_condition = GetInput(context, node, kInputTensorCondition); - TfLiteTensor* input_x = GetInput(context, node, kInputTensorX); - TfLiteTensor* input_y = GetInput(context, node, kInputTensorY); + const TfLiteTensor* input_x = GetInput(context, node, kInputTensorX); + const TfLiteTensor* input_y = GetInput(context, node, kInputTensorY); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool is_rank_one = !HaveSameShapes(input_condition, input_x); @@ -97,7 +97,9 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) { break; \ default: \ context->ReportError(context, \ - "Does not support type other than bool|float|int"); \ + "Does not support type other than bool|float|int, " \ + "got %d", \ + type); \ return kTfLiteError; \ } diff --git a/tensorflow/contrib/lite/kernels/slice.cc 
b/tensorflow/contrib/lite/kernels/slice.cc new file mode 100644 index 0000000000000000000000000000000000000000..6a20e802a99cdf23a005a8cd9f1fd97b03c8070a --- /dev/null +++ b/tensorflow/contrib/lite/kernels/slice.cc @@ -0,0 +1,201 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace slice { + +constexpr int kInputTensor = 0; +constexpr int kBeginTensor = 1; +constexpr int kSizeTensor = 2; +constexpr int kOutputTensor = 0; + +// This Op only supports 1-4D cases and since we use the optimized ops 4D +// implementation, the 1-3D tensors are mapped to 4D. 
+const int kMaxDim = 4; + +template +TfLiteStatus CalculateOutputShapeVector( + TfLiteContext* context, const TfLiteTensor* input, + const TfLiteTensor* begin, const TfLiteTensor* size, + std::vector* output_shape_vector) { + for (int idx = 0; idx < NumDimensions(input); ++idx) { + T size_value = GetTensorData(size)[idx]; + if (size_value < 0) { + if (size_value != -1) { + context->ReportError(context, "Invalid size."); + return kTfLiteError; + } + size_value = SizeOfDimension(input, idx) - GetTensorData(begin)[idx]; + } else { + if (SizeOfDimension(input, idx) < + GetTensorData(begin)[idx] + size_value) { + context->ReportError(context, "Invalid begin and size."); + return kTfLiteError; + } + } + output_shape_vector->push_back(size_value); + } + return kTfLiteOk; +} + +template +void GetBeginAndSizeVectors(int dimensions, const TfLiteTensor* begin, + const TfLiteTensor* size, std::vector* begins, + std::vector* sizes) { + for (int idx = dimensions - 1; idx >= 0; --idx) { + begins->push_back(GetTensorData(begin)[idx]); + sizes->push_back(GetTensorData(size)[idx]); + } +} + +TfLiteStatus ResizeOutputShape(TfLiteContext* context, + const TfLiteTensor* input, + const TfLiteTensor* begin, + const TfLiteTensor* size, TfLiteTensor* output) { + std::vector output_shape_vector; + + if (begin->type == kTfLiteInt32) { + TF_LITE_ENSURE_STATUS(CalculateOutputShapeVector( + context, input, begin, size, &output_shape_vector)); + } else if (begin->type == kTfLiteInt64) { + TF_LITE_ENSURE_STATUS(CalculateOutputShapeVector( + context, input, begin, size, &output_shape_vector)); + } else { + context->ReportError( + context, "Type %d is currently not supported by Slice.", begin->type); + return kTfLiteError; + } + + TfLiteIntArray* output_shape = + TfLiteIntArrayCreate(output_shape_vector.size()); + std::copy(output_shape_vector.begin(), output_shape_vector.end(), + output_shape->data); + return context->ResizeTensor(context, output, output_shape); +} + +TfLiteStatus 
Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* begin = GetInput(context, node, kBeginTensor); + const TfLiteTensor* size = GetInput(context, node, kSizeTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // Ensure validity of input tensor and its dimension. + TF_LITE_ENSURE_EQ(context, input->type, output->type); + TF_LITE_ENSURE(context, + begin->type == kTfLiteInt32 || begin->type == kTfLiteInt64); + TF_LITE_ENSURE(context, + size->type == kTfLiteInt32 || size->type == kTfLiteInt64); + TF_LITE_ENSURE(context, NumDimensions(begin) == 1 && NumDimensions(size) == 1); + TF_LITE_ENSURE_MSG(context, NumDimensions(input) <= kMaxDim, + "Slice op only supports 1D-4D input arrays."); + + // Postpone allocation of output if any of the indexing tensors is not + // constant + if (!(IsConstantTensor(begin) && IsConstantTensor(size))) { + SetTensorToDynamic(output); + return kTfLiteOk; + } + + return ResizeOutputShape(context, input, begin, size, output); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* begin = GetInput(context, node, kBeginTensor); + const TfLiteTensor* size = GetInput(context, node, kSizeTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, + ResizeOutputShape(context, input, begin, size, output)); + } + + std::vector begins; + begins.reserve(kMaxDim); + std::vector sizes; + sizes.reserve(kMaxDim); + + if (begin->type == kTfLiteInt32) { + GetBeginAndSizeVectors(NumDimensions(input), begin, size, &begins, + &sizes); + } else if (begin->type == kTfLiteInt64) { + GetBeginAndSizeVectors(NumDimensions(input), begin, size, &begins, + &sizes); + } else { +
context->ReportError( + context, "Type %d is currently not supported by Slice.", begin->type); + return kTfLiteError; + } + + for (int i = NumDimensions(input); i < kMaxDim; ++i) { + begins.push_back(0); + sizes.push_back(1); + } + +#define TF_LITE_SLICE(data_type) \ + optimized_ops::Slice( \ + GetTensorData(input), GetTensorDims(input), begins, sizes, \ + GetTensorData(output), GetTensorDims(output)) + + switch (input->type) { + case kTfLiteFloat32: + TF_LITE_SLICE(float); + break; + case kTfLiteInt32: + TF_LITE_SLICE(int32_t); + break; + case kTfLiteInt64: + TF_LITE_SLICE(int64_t); + break; + case kTfLiteUInt8: + TF_LITE_SLICE(uint8_t); + break; + case kTfLiteBool: + TF_LITE_SLICE(bool); + break; + default: + context->ReportError( + context, "Type %d is currently not supported by Slice.", input->type); + return kTfLiteError; + } +#undef TF_LITE_SLICE + return kTfLiteOk; +} + +} // namespace slice + +TfLiteRegistration* Register_SLICE() { + static TfLiteRegistration r = {nullptr, nullptr, slice::Prepare, slice::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/slice_test.cc b/tensorflow/contrib/lite/kernels/slice_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4828f88f36bc1e7daf84ab6831a2ccc98bfaed40 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/slice_test.cc @@ -0,0 +1,173 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template +class SliceOpModel : public SingleOpModel { + public: + SliceOpModel(std::initializer_list input_shape, + std::initializer_list begin_shape, + std::initializer_list size_shape, + TensorType tensor_index_type, TensorType tensor_input_type) { + input_ = AddInput(tensor_input_type); + begin_ = AddInput(tensor_index_type); + size_ = AddInput(tensor_index_type); + output_ = AddOutput(tensor_input_type); + SetBuiltinOp(BuiltinOperator_SLICE, BuiltinOptions_SliceOptions, + CreateSliceOptions(builder_).Union()); + BuildInterpreter({input_shape, begin_shape, size_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetBegin(std::initializer_list data) { + PopulateTensor(begin_, data); + } + void SetSize(std::initializer_list data) { + PopulateTensor(size_, data); + } + + std::vector GetOutput() { + return ExtractVector(output_); + } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int begin_; + int size_; + int output_; +}; + +TEST(SliceOpTest, In1D) { + SliceOpModel m({4}, {1}, {1}, TensorType_INT32, + TensorType_FLOAT32); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1}); + m.SetSize({2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3})); +} + +TEST(SliceOpTest, In2D) { + SliceOpModel m({2, 3}, {2}, {2}, TensorType_INT32, + TensorType_FLOAT32); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, 0}); + m.SetSize({1, 2}); + 
m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5})); +} + +TEST(SliceOpTest, In3D) { + SliceOpModel m({2, 3, 2}, {3}, {4}, TensorType_INT32, + TensorType_FLOAT32); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetSize({2, 3, 2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); +} + +TEST(SliceOpTest, InputFloat) { + SliceOpModel m({4, 1, 1, 1}, {4}, {4}, TensorType_INT32, + TensorType_FLOAT32); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1, 0, 0, 0}); + m.SetSize({3, 1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 1, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4})); +} + +TEST(SliceOpTest, IndexInt64) { + SliceOpModel m({4, 1, 1, 1}, {4}, {4}, TensorType_INT64, + TensorType_FLOAT32); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1, 0, 0, 0}); + m.SetSize({3, 1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 1, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4})); +} + +// See these test cases under: +// https://www.tensorflow.org/versions/master/api_docs/python/tf/slice +TEST(SliceOpTest, InputInteger1) { + SliceOpModel m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32, + TensorType_INT32); + m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + m.SetBegin({1, 0, 0, 0}); + m.SetSize({1, 1, 3, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 3, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3})); +} + +TEST(SliceOpTest, InputInteger2) { + SliceOpModel m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32, + TensorType_INT32); + m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + m.SetBegin({1, 0, 0, 0}); + m.SetSize({1, 2, 3, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), 
ElementsAreArray({1, 2, 3, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 4, 4, 4})); +} + +TEST(SliceOpTest, InputInteger3) { + SliceOpModel m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32, + TensorType_INT32); + m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + m.SetBegin({1, 0, 0, 0}); + m.SetSize({2, 1, 3, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5})); +} + +TEST(SliceOpTest, SizeMinus1) { + SliceOpModel m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32, + TensorType_INT32); + m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + m.SetBegin({1, 0, 0, 0}); + m.SetSize({2, 1, -1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc index d8c9e352f00627eee45ae836b720f2af77140538..c9269599e58f095ded4788e2ab064583ae0a708c 100644 --- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc +++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc @@ -40,9 +40,9 @@ struct SpaceToBatchNDContext { paddings = GetInput(context, node, 2); output = GetOutput(context, node, 0); } - TfLiteTensor* input; - TfLiteTensor* block_shape; - TfLiteTensor* paddings; + const TfLiteTensor* input; + const TfLiteTensor* block_shape; + const TfLiteTensor* paddings; TfLiteTensor* output; }; @@ -152,8 +152,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } break; default: - context->ReportError(context, - "Type is currently not supported by SpaceToBatch."); + context->ReportError( + context, "Type %d is currently not supported by 
SpaceToBatch.", + op_context.input->type); return kTfLiteError; } #undef TF_LITE_SPACE_TO_BATCH_ND diff --git a/tensorflow/contrib/lite/kernels/space_to_depth.cc b/tensorflow/contrib/lite/kernels/space_to_depth.cc index cb2e509c9811b1469c4d3f676532edff570a6c4a..9dbe9b9edaccc3ea75f1997378aba5a218ee3030 100644 --- a/tensorflow/contrib/lite/kernels/space_to_depth.cc +++ b/tensorflow/contrib/lite/kernels/space_to_depth.cc @@ -42,7 +42,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); @@ -76,7 +76,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); #define TF_LITE_SPACE_TO_DEPTH(type, scalar) \ @@ -113,7 +113,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } break; default: - context->ReportError(context, "Type not currently supported."); + context->ReportError(context, "Type %d not currently supported.", + input->type); return kTfLiteError; } #undef TF_LITE_SPACE_TO_DEPTH diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc new file mode 100644 index 0000000000000000000000000000000000000000..404c32ad9ca8b9f1e467b747708ccb451f2a5118 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc @@ -0,0 +1,275 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/kernels/padding.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace sparse_to_dense { + +constexpr int kIndicesTensor = 0; +constexpr int kOutputShapeTensor = 1; +constexpr int kValueInputTensor = 2; +constexpr int kDefaultValueTensor = 3; +constexpr int kOutputTensor = 0; + +constexpr int kMaxDimensions = 4; + +template +TfLiteStatus Resize(TfLiteContext* context, const TfLiteTensor* output_shape, + TfLiteTensor* output) { + const int output_dimensions = NumElements(output_shape); + TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(output_dimensions); + for (int i = 0; i < output_dimensions; ++i) { + output_shape_array->data[i] = GetTensorData(output_shape)[i]; + } + + return context->ResizeTensor(context, output, output_shape_array); +} + +TfLiteStatus CheckDimensionsMatch(TfLiteContext* context, + const TfLiteTensor* indices, + const TfLiteTensor* output_shape, + const 
TfLiteTensor* values) { + switch (NumDimensions(indices)) { + case 0: + case 1: { + if (NumDimensions(values) == 0) { + TF_LITE_ENSURE_EQ(context, NumElements(indices), NumElements(values)); + } + TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 1); + break; + } + case 2: { + TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 1), + NumElements(output_shape)); + if (NumDimensions(values) == 0) + TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0), + NumElements(values)); + break; + } + default: + context->ReportError( + context, "Wrong indices dimensions %d, should be less than 3.", + NumDimensions(indices)); + return kTfLiteError; + } + return kTfLiteOk; +} + +// Convert indices into a vector of 4-d vectors. +// TODO(renjieliu): Revisit here to improve the performance, since multiple +// allocations of std::vectors will be quite slow on phones. +template +TfLiteStatus GetIndicesVector(TfLiteContext* context, + const TfLiteTensor* indices, + const int num_indices, + std::vector>* indices_vector) { + // Note: because TfLite reverses the dimensions, pad zeros up front. + switch (NumDimensions(indices)) { + case 0: + case 1: { + const auto indices_data = GetTensorData(indices); + for (int i = 0; i < num_indices; ++i) { + std::vector index({0, 0, 0, indices_data[i]}); + indices_vector->push_back(index); + } + break; + } + case 2: { + const int true_dimensions = SizeOfDimension(indices, 1); + TF_LITE_ENSURE(context, true_dimensions <= kMaxDimensions); + for (int i = 0; i < num_indices; ++i) { + std::vector index; + index.reserve(kMaxDimensions); + // Fill the index with 0 up to kMaxDimensions - true_dimensions to + // satisfy the needs for 4-dimension index.
+ for (int j = 0; j < kMaxDimensions - true_dimensions; ++j) { + index.push_back(0); + } + for (int j = 0; j < true_dimensions; ++j) { + index.push_back(GetTensorData(indices)[i * true_dimensions + j]); + } + + indices_vector->push_back(index); + } + break; + } + default: + context->ReportError(context, + "Indices dimensions problem, got %d dimensions", + NumDimensions(indices)); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus ResizeOutputShape(TfLiteContext* context, + const TfLiteTensor* output_shape, + TfLiteTensor* output) { + if (output_shape->type == kTfLiteInt32) { + return Resize(context, output_shape, output); + } else if (output_shape->type == kTfLiteInt64) { + return Resize(context, output_shape, output); + } else { + context->ReportError(context, "Dense shape type %d not supported.", + output_shape->type); + return kTfLiteError; + } +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor); + const TfLiteTensor* output_shape = + GetInput(context, node, kOutputShapeTensor); + const TfLiteTensor* values = GetInput(context, node, kValueInputTensor); + const TfLiteTensor* default_value = + GetInput(context, node, kDefaultValueTensor); + + // TODO(renjieliu): Handle validate_indices. + + // Indices can be 0-D, 1-D or 2-D. + TF_LITE_ASSERT(NumDimensions(indices) >= 0); + TF_LITE_ENSURE(context, NumDimensions(indices) < 3); + TF_LITE_ASSERT(NumDimensions(output_shape) >= 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1); + // Values can be 0-D or 1-D. 
+ TF_LITE_ASSERT(NumDimensions(values) >= 0); + TF_LITE_ENSURE(context, NumDimensions(values) < 2); + + TF_LITE_ENSURE_EQ(context, NumElements(default_value), 1); + + TF_LITE_ENSURE( + context, indices->type == kTfLiteInt32 || indices->type == kTfLiteInt64); + TF_LITE_ENSURE(context, output_shape->type == kTfLiteInt32 || + output_shape->type == kTfLiteInt64); + TF_LITE_ENSURE_EQ(context, values->type, default_value->type); + + // Ensure dimensions match. + TF_LITE_ENSURE_OK( + context, CheckDimensionsMatch(context, indices, output_shape, values)); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1); + + if (!IsConstantTensor(output_shape)) { + SetTensorToDynamic(output); + return kTfLiteOk; + } + return ResizeOutputShape(context, output_shape, output); +} + +template +TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor); + const TfLiteTensor* output_shape = + GetInput(context, node, kOutputShapeTensor); + const TfLiteTensor* values = GetInput(context, node, kValueInputTensor); + const TfLiteTensor* default_value = + GetInput(context, node, kDefaultValueTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, + ResizeOutputShape(context, output_shape, output)); + } + + const int num_indices = SizeOfDimension(indices, 0); + const bool value_is_scalar = NumDimensions(values) == 0; + std::vector> indices_vector; + indices_vector.reserve(num_indices); + TF_LITE_ENSURE_OK(context, GetIndicesVector(context, indices, num_indices, + &indices_vector)); + reference_ops::SparseToDense(indices_vector, GetTensorData(values), + *GetTensorData(default_value), + GetTensorData(output), GetTensorDims(output), + value_is_scalar); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const 
TfLiteTensor* indices = GetInput(context, node, kIndicesTensor); + const TfLiteTensor* values = GetInput(context, node, kValueInputTensor); + + // Currently only supports float32 and int32. + switch (values->type) { + case kTfLiteFloat32: { + switch (indices->type) { + case kTfLiteInt32: { + return SparseToDenseImpl(context, node); + } + case kTfLiteInt64: { + return SparseToDenseImpl(context, node); + } + default: + context->ReportError( + context, "Type %d is currently not supported by sparse to dense.", + indices->type); + return kTfLiteError; + } + break; + } + case kTfLiteInt32: { + switch (indices->type) { + case kTfLiteInt32: { + return SparseToDenseImpl(context, node); + } + case kTfLiteInt64: { + return SparseToDenseImpl(context, node); + } + default: + context->ReportError( + context, "Type %d is currently not supported by sparse to dense.", + indices->type); + return kTfLiteError; + } + break; + } + default: + context->ReportError( + context, "Type %d is currently not supported by sparse to dense.", + values->type); + return kTfLiteError; + } +} + +} // namespace sparse_to_dense + +TfLiteRegistration* Register_SPARSE_TO_DENSE() { + static TfLiteRegistration r = {nullptr, nullptr, sparse_to_dense::Prepare, + sparse_to_dense::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a51ec17afcefd791680d7aa42cef467f481f6dbc --- /dev/null +++ b/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc @@ -0,0 +1,155 @@ + +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template +class SparseToDenseOpModel : public SingleOpModel { + public: + SparseToDenseOpModel(std::initializer_list indices_shape, + std::initializer_list output_shape_shape, + std::initializer_list values_shape, T default_value, + TensorType tensor_index_type, + TensorType tensor_input_type) { + indices_ = AddInput(tensor_index_type); + output_shape_ = AddInput(TensorType_INT32); + values_ = AddInput(tensor_input_type); + default_value_ = AddInput(tensor_input_type); + output_ = AddOutput(tensor_input_type); + + SetBuiltinOp(BuiltinOperator_SPARSE_TO_DENSE, + BuiltinOptions_SparseToDenseOptions, + CreateSparseToDenseOptions(builder_, false).Union()); + BuildInterpreter({indices_shape, output_shape_shape, values_shape, {1}}); + + PopulateTensor(default_value_, {default_value}); + } + + int indices() { return indices_; } + int output_shape() { return output_shape_; } + int values() { return values_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int indices_; + int output_shape_; + int values_; + int default_value_; + int output_; +}; + +TEST(SparseToDenseOpModelTest, ZeroDimensionTest) { + 
SparseToDenseOpModel m({1}, {1}, {1}, 0, TensorType_INT32, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {3}); + m.PopulateTensor(m.output_shape(), {5}); + m.PopulateTensor(m.values(), {7}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 7, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({5})); +} + +TEST(SparseToDenseOpModelTest, OneDimensionTest) { + SparseToDenseOpModel m({3}, {1}, {3}, 0, TensorType_INT32, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {1, 3, 5}); + m.PopulateTensor(m.output_shape(), {7}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 0, 4, 0, 6, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({7})); +} + +TEST(SparseToDenseOpModelTest, TwoDimensionsTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, 0, TensorType_INT32, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 4, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + +TEST(SparseToDenseOpModelTest, DefaultValueTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, -1, TensorType_INT32, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + +TEST(SparseToDenseOpModelTest, IntegerValueTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, -1, TensorType_INT32, + TensorType_INT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 
2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + +TEST(SparseToDenseOpModelTest, Int64IndexTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, -1, TensorType_INT64, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc index b524c79f8779b0119781679c0af9fe354e38ad4f..43387df9ceb4d54a2784c3fa4718a95262948729 100644 --- a/tensorflow/contrib/lite/kernels/split.cc +++ b/tensorflow/contrib/lite/kernels/split.cc @@ -34,8 +34,8 @@ struct OpContext { input = GetInput(context, node, 1); } TfLiteSplitParams* params; - TfLiteTensor* axis; - TfLiteTensor* input; + const TfLiteTensor* axis; + const TfLiteTensor* input; }; TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context, TfLiteNode* node) { @@ -46,8 +46,8 @@ TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context, TfLiteNode* node) { } TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node, - TfLiteTensor* axis, TfLiteTensor* input, - int num_splits) { + const TfLiteTensor* axis, + const TfLiteTensor* input, int num_splits) { int axis_value = GetTensorData(axis)[0]; 
if (axis_value < 0) { axis_value += NumDimensions(input); @@ -138,8 +138,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { break; } default: - context->ReportError(context, - "Only float32 and uint8 are currently supported."); + context->ReportError( + context, "Only float32 and uint8 are currently supported, got %d.", + op_context.input->type); return kTfLiteError; } #undef TF_LITE_SPLIT diff --git a/tensorflow/contrib/lite/kernels/squeeze.cc b/tensorflow/contrib/lite/kernels/squeeze.cc index 29447ab021c7b68ff51070d35262402e08dc7ab9..09a5662fd9e70da700c629d94453cb87ad37c448 100644 --- a/tensorflow/contrib/lite/kernels/squeeze.cc +++ b/tensorflow/contrib/lite/kernels/squeeze.cc @@ -26,13 +26,12 @@ namespace builtin { namespace squeeze { struct SqueezeContext { - SqueezeContext(TfLiteContext* context, TfLiteNode* node) { - params = reinterpret_cast(node->builtin_data); - input = GetInput(context, node, 0); - output = GetOutput(context, node, 0); - } + SqueezeContext(TfLiteContext* context, TfLiteNode* node) + : params(reinterpret_cast(node->builtin_data)), + input(GetInput(context, node, 0)), + output(GetOutput(context, node, 0)) {} TfLiteSqueezeParams* params; - TfLiteTensor* input; + const TfLiteTensor* const input; TfLiteTensor* output; }; diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc index 40ac436b7dcabe7a12166e5381f0381941a204d3..725dd8105ab9506d5203ed38a11f8e06abdab603 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice.cc @@ -49,10 +49,10 @@ struct StridedSliceContext { dims = NumDimensions(input); } const TfLiteStridedSliceParams* params; - TfLiteTensor* input; - TfLiteTensor* begin; - TfLiteTensor* end; - TfLiteTensor* strides; + const TfLiteTensor* input; + const TfLiteTensor* begin; + const TfLiteTensor* end; + const TfLiteTensor* strides; TfLiteTensor* output; int dims; }; @@ -235,8 +235,9 @@ TfLiteStatus 
Eval(TfLiteContext* context, TfLiteNode* node) { break; default: context->ReportError(context, - "Type is currently not supported " - "by StridedSlice."); + "Type %d is currently not supported " + "by StridedSlice.", + op_context.input->type); return kTfLiteError; } #undef TF_LITE_STRIDED_SLICE diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc index 7c60a4fdbffdc96b8967f52f8dbab3e18ecbcc0a..bdcaab8e2fa8a3342e0958635ec77a15a7238ccf 100644 --- a/tensorflow/contrib/lite/kernels/sub.cc +++ b/tensorflow/contrib/lite/kernels/sub.cc @@ -57,8 +57,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); - TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TF_LITE_ENSURE_EQ(context, input1->type, input2->type); @@ -80,7 +80,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { template void EvalFloat(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params, const OpData* data, - TfLiteTensor* input1, TfLiteTensor* input2, + const TfLiteTensor* input1, const TfLiteTensor* input2, TfLiteTensor* output) { float output_activation_min, output_activation_max; CalculateActivationRangeFloat(params->activation, &output_activation_min, @@ -109,7 +109,7 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, template void EvalQuantized(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params, const OpData* data, - TfLiteTensor* input1, TfLiteTensor* input2, + const TfLiteTensor* input1, const TfLiteTensor* input2, TfLiteTensor* output) { auto input1_offset = -input1->params.zero_point; auto input2_offset = 
-input2->params.zero_point; @@ -164,8 +164,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); OpData* data = reinterpret_cast(node->user_data); - TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); - TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); if (output->type == kTfLiteFloat32) { @@ -174,8 +174,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { EvalQuantized(context, node, params, data, input1, input2, output); } else { - context->ReportError(context, - "Inputs and outputs not all float|uint8 types."); + context->ReportError( + context, "output type %d is not supported, requires float|uint8 types.", + output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc index 13da51c7a78c362160e2ecd6121aa31bc7ce5355..308860c299e9d74729d35b760e0f605437872c92 100644 --- a/tensorflow/contrib/lite/kernels/svdf.cc +++ b/tensorflow/contrib/lite/kernels/svdf.cc @@ -58,9 +58,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; - TfLiteTensor* weights_feature = + const TfLiteTensor* weights_feature = GetInput(context, node, kWeightsFeatureTensor); - TfLiteTensor* weights_time = GetInput(context, node, kWeightsTimeTensor); + const TfLiteTensor* weights_time = + GetInput(context, node, kWeightsTimeTensor); // Check all the parameters of tensor match within themselves and match the // input configuration. 
@@ -73,7 +74,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ASSERT_EQ(input->dims->data[1], weights_feature->dims->data[1]); TF_LITE_ASSERT_EQ(weights_time->dims->data[0], num_filters); - TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); if (bias) { TF_LITE_ASSERT_EQ(bias->dims->data[0], num_units); } @@ -123,16 +124,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); - TfLiteTensor* input = GetInput(context, node, kInputTensor); - TfLiteTensor* weights_feature = + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* weights_feature = GetInput(context, node, kWeightsFeatureTensor); - TfLiteTensor* weights_time = GetInput(context, node, kWeightsTimeTensor); + const TfLiteTensor* weights_time = + GetInput(context, node, kWeightsTimeTensor); TfLiteTensor* state = GetOutput(context, node, kStateTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0); - TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); const int rank = params->rank; const int batch_size = input->dims->data[0]; diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc index 5a6c85e97ef5f0c015dcd6cb89dba85cdf4ae937..d23ec201b41887b0682242687fc938d76d058c44 100644 --- a/tensorflow/contrib/lite/kernels/test_util.cc +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -101,7 +101,7 @@ void SingleOpModel::BuildInterpreter( } resolver_ = std::unique_ptr(resolver); } - InterpreterBuilder(model, *resolver_)(&interpreter_); + CHECK(InterpreterBuilder(model, *resolver_)(&interpreter_) 
== kTfLiteOk); CHECK(interpreter_ != nullptr); @@ -112,6 +112,12 @@ void SingleOpModel::BuildInterpreter( if (shape.empty()) continue; CHECK(interpreter_->ResizeInputTensor(input_idx, shape) == kTfLiteOk); } + + // Modify delegate with function. + if (apply_delegate_fn_) { + apply_delegate_fn_(interpreter_.get()); + } + CHECK(interpreter_->AllocateTensors() == kTfLiteOk) << "Cannot allocate tensors"; } diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h index 6a9fdf11122da5c9039a739036bd19f431149ca8..db80c0082c394a2cb2f9388d3db5bd1a7cbe6266 100644 --- a/tensorflow/contrib/lite/kernels/test_util.h +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -89,18 +89,24 @@ struct TensorData { class SingleOpResolver : public OpResolver { public: SingleOpResolver(const BuiltinOperator op, TfLiteRegistration* registration) - : op_(op), registration_(registration) {} - TfLiteRegistration* FindOp(BuiltinOperator op) const override { + : op_(op), registration_(*registration) { + registration_.builtin_code = static_cast(op); + registration_.version = 1; + } + const TfLiteRegistration* FindOp(BuiltinOperator op, + int version) const override { if (op == op_) { - return registration_; + return &registration_; } return nullptr; } - TfLiteRegistration* FindOp(const char* op) const override { return nullptr; } + const TfLiteRegistration* FindOp(const char* op, int version) const override { + return nullptr; + } private: const BuiltinOperator op_; - TfLiteRegistration* registration_; + TfLiteRegistration registration_; }; class SingleOpModel { @@ -108,6 +114,13 @@ class SingleOpModel { SingleOpModel() {} ~SingleOpModel() {} + // Set a function callback that is run right after graph is prepared + // that allows applying external delegates. This is useful for testing + // other runtimes like NN API or GPU.
+ void SetApplyDelegate(std::function apply_delegate_fn) { + apply_delegate_fn_ = apply_delegate_fn; + } + // Copying or assignment is disallowed to simplify ownership semantics. SingleOpModel(const SingleOpModel&) = delete; SingleOpModel& operator=(const SingleOpModel&) = delete; @@ -311,6 +324,9 @@ class SingleOpModel { std::vector> operators_; std::vector> buffers_; std::map> custom_registrations_; + // A function pointer that gets called after the interpreter is created but + // before evaluation happens. This is useful for applying a delegate. + std::function apply_delegate_fn_; }; // Base class for single op unit tests. diff --git a/tensorflow/contrib/lite/kernels/tile.cc b/tensorflow/contrib/lite/kernels/tile.cc new file mode 100644 index 0000000000000000000000000000000000000000..af77f074742eb3fef10a74616ff679255911fbb2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/tile.cc @@ -0,0 +1,194 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +namespace tflite { +namespace ops { +namespace builtin { +namespace tile { + +constexpr int kInputTensor = 0; +constexpr int kInputMultipliers = 1; +constexpr int kOutputTensor = 0; + +namespace { +template +TfLiteIntArray* MultiplyShapeDims(const TfLiteIntArray& shape, + const TfLiteTensor* multipliers, + int num_dimensions) { + const T* multipliers_v = GetTensorData(multipliers); + + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); + for (int i = 0; i < num_dimensions; ++i) { + output_shape->data[i] = shape.data[i] * multipliers_v[i]; + } + return output_shape; +} + +TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers); + + const int num_dimensions = NumDimensions(input); + const int num_multipliers = NumElements(multipliers); + TF_LITE_ENSURE_EQ(context, num_dimensions, num_multipliers); + switch (multipliers->type) { + case kTfLiteInt32: + return context->ResizeTensor( + context, output, + MultiplyShapeDims(*input->dims, multipliers, + num_dimensions)); + case kTfLiteInt64: + return context->ResizeTensor( + context, output, + MultiplyShapeDims(*input->dims, multipliers, + num_dimensions)); + default: + context->ReportError(context, "Tile not supported multiply tensor type."); + return kTfLiteError; + } +} + +template +void CopyMultipleTimes(const T* in_data, int32_t 
in_size, int32_t multiplier, + T* out_data) { + for (int i = 0; i < multiplier; ++i) { + const T* in_end = in_data + in_size; + T* new_out_data = std::copy(in_data, in_end, out_data); + in_data = out_data; + out_data = new_out_data; + } +} + +template +std::pair TileOneDimension(const TfLiteIntArray& in_dimensions, + const T* in_data, const M* multipliers, + T* out_data, int dimension) { + const int dimension_size = in_dimensions.data[dimension]; + if (dimension == in_dimensions.size - 1) { + CopyMultipleTimes(in_data, dimension_size, multipliers[dimension], + out_data); + return std::make_pair(dimension_size, + dimension_size * multipliers[dimension]); + } + int total_stride_size = 0, total_tiled_stride_size = 0; + const T* copy_from_data = in_data; + T* copy_to_data = out_data; + for (int i = 0; i < dimension_size; ++i) { + int stride_size = 0, tiled_stride_size = 0; + std::tie(stride_size, tiled_stride_size) = + TileOneDimension(in_dimensions, copy_from_data, multipliers, + copy_to_data, dimension + 1); + copy_from_data += stride_size; + copy_to_data += tiled_stride_size; + total_stride_size += stride_size; + total_tiled_stride_size += tiled_stride_size; + } + CopyMultipleTimes(out_data, total_tiled_stride_size, + multipliers[dimension] - 1, + out_data + total_tiled_stride_size); + return std::make_pair(total_stride_size, + total_tiled_stride_size * multipliers[dimension]); +} + +template +void Tile(const TfLiteIntArray& in_dimensions, const TfLiteTensor* in_data, + const TfLiteTensor* multipliers, TfLiteTensor* out_data) { + // Tile recursively, starting from the outermost dimension.
+ switch (multipliers->type) { + case kTfLiteInt32: + TileOneDimension(in_dimensions, GetTensorData(in_data), + GetTensorData(multipliers), + GetTensorData(out_data), 0); + break; + case kTfLiteInt64: + TileOneDimension(in_dimensions, GetTensorData(in_data), + GetTensorData(multipliers), + GetTensorData(out_data), 0); + break; + default: + break; + } +} +} // namespace + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers); + // Only int32 and int64 multiplier types are supported. + TF_LITE_ENSURE_MSG(context, + (multipliers->type == kTfLiteInt32) || + (multipliers->type == kTfLiteInt64), + "Tile only supports int32 and int64 multipliers."); + + if (IsConstantTensor(multipliers)) { + TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); + } else { + SetTensorToDynamic(output); + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers); + + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); + } + + switch (output->type) { + case kTfLiteFloat32: + Tile(*(input->dims), input, multipliers, output); + break; + case kTfLiteUInt8: + Tile(*(input->dims), input, multipliers, output); + break; + case kTfLiteInt32: + Tile(*(input->dims), input, multipliers, output); + break; + case kTfLiteInt64: + Tile(*(input->dims), input, multipliers, output); + break; + default: + context->ReportError(context, "Type is
currently not supported by Tile."); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace tile +TfLiteRegistration* Register_TILE() { + static TfLiteRegistration r = {nullptr, nullptr, tile::Prepare, tile::Eval}; + return &r; +} +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/tile_test.cc b/tensorflow/contrib/lite/kernels/tile_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a134a75d56ae03a5d03a3cdf632146474b863971 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/tile_test.cc @@ -0,0 +1,256 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; +class TileOpModel : public SingleOpModel { + public: + TileOpModel(std::initializer_list input_shape, TensorType input_type, + TensorType multiply_type) { + input_ = AddInput(input_type); + multipliers_ = AddInput(TensorType_INT32); + output_ = AddOutput(input_type); + SetBuiltinOp(BuiltinOperator_TILE, BuiltinOptions_TileOptions, 0); + BuildInterpreter({input_shape, {static_cast(input_shape.size())}}); + } + + void SetInputFloat(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputUInt8(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputInt32(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputInt64(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetMultipliers(std::initializer_list data) { + PopulateTensor(multipliers_, data); + } + + std::vector GetOutputFloat() { return ExtractVector(output_); } + + std::vector GetOutputUInt8() { return ExtractVector(output_); } + + std::vector GetOutputInt32() { return ExtractVector(output_); } + + std::vector GetOutputInt64() { + return ExtractVector(output_); + } + + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int multipliers_; + int output_; +}; + +TEST(TileTest, Float32Vector) { + TileOpModel m({3}, TensorType_FLOAT32, TensorType_INT32); + m.SetInputFloat({1.f, 2.f, 3.f}); + m.SetMultipliers({2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), + ElementsAreArray({1.f, 2.f, 3.f, 1.f, 2.f, 3.f})); +} + +TEST(TileTest, Float32Matrix) { + 
TileOpModel m({2, 3}, TensorType_FLOAT32, TensorType_INT32); + m.SetInputFloat({ + 11.f, + 12.f, + 13.f, + 21.f, + 22.f, + 23.f, + }); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray({ + 11.f, + 12.f, + 13.f, + 21.f, + 22.f, + 23.f, + 11.f, + 12.f, + 13.f, + 21.f, + 22.f, + 23.f, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} + +TEST(TileTest, Float32HighDimension) { + TileOpModel m({1, 2, 3}, TensorType_FLOAT32, TensorType_INT32); + m.SetInputFloat({ + 11.f, + 12.f, + 13.f, + 21.f, + 22.f, + 23.f, + }); + m.SetMultipliers({2, 3, 1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutputFloat(), + ElementsAreArray({11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, + 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, 21.f, 22.f, 23.f, + 11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, + 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, 21.f, 22.f, 23.f})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 6, 3})); +} + +TEST(TileTest, Uint8Matrix) { + TileOpModel m({2, 3}, TensorType_UINT8, TensorType_INT32); + m.SetInputUInt8({ + 11, + 12, + 13, + 21, + 22, + 23, + }); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputUInt8(), ElementsAreArray({ + 11, + 12, + 13, + 21, + 22, + 23, + 11, + 12, + 13, + 21, + 22, + 23, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} + +TEST(TileTest, Int32Matrix) { + TileOpModel m({2, 3}, TensorType_INT32, TensorType_INT32); + m.SetInputInt32({ + 11, + 12, + 13, + 21, + 22, + 23, + }); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputInt32(), ElementsAreArray({ + 11, + 12, + 13, + 21, + 22, + 23, + 11, + 12, + 13, + 21, + 22, + 23, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} + +TEST(TileTest, Int64Matrix) { + TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT32); + m.SetInputInt64({ + 11, + 12, + 13, + 21, + 22, + 23, + }); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputInt64(), 
ElementsAreArray({ + 11, + 12, + 13, + 21, + 22, + 23, + 11, + 12, + 13, + 21, + 22, + 23, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} + +TEST(TileTest, Int64Matrix64Multipliers) { + TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT64); + m.SetInputInt64({ + 11, + 12, + 13, + 21, + 22, + 23, + }); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputInt64(), ElementsAreArray({ + 11, + 12, + 13, + 21, + 22, + 23, + 11, + 12, + 13, + 21, + 22, + 23, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/contrib/lite/kernels/topk_v2.cc index ad9b744f1af2715a37cc60ef61b0b9540fe2532b..fb0e49c90c41747f9b7e53570276c8b8045030fd 100644 --- a/tensorflow/contrib/lite/kernels/topk_v2.cc +++ b/tensorflow/contrib/lite/kernels/topk_v2.cc @@ -30,15 +30,14 @@ constexpr int kOutputIndexes = 1; namespace { TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* top_k = GetInput(context, node, kInputTopK); + const TfLiteTensor* top_k = GetInput(context, node, kInputTopK); // INT32 number of top results is supported. TF_LITE_ENSURE_EQ(context, top_k->type, kTfLiteInt32); // Check that the tensor contains only one value. - TF_LITE_ENSURE_EQ(context, NumDimensions(top_k), 1); TF_LITE_ENSURE_EQ(context, NumElements(top_k), 1); - const int32 k = top_k->data.i32[0]; + const int32 k = *GetTensorData(top_k); - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); const int num_dimensions = NumDimensions(input); // Check that input has one or more dimensions. 
TF_LITE_ENSURE_MSG(context, input->dims->size >= 1, @@ -162,11 +161,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2); - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output_values = GetOutput(context, node, kOutputValues); TF_LITE_ENSURE_EQ(context, input->type, output_values->type); - TfLiteTensor* top_k = GetInput(context, node, kInputTopK); + const TfLiteTensor* top_k = GetInput(context, node, kInputTopK); TF_LITE_ENSURE_EQ(context, top_k->type, kTfLiteInt32); // Set output dynamic if the input is not const. @@ -187,11 +186,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (IsDynamicTensor(output_values)) { TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); } - TfLiteTensor* top_k = GetInput(context, node, kInputTopK); + const TfLiteTensor* top_k = GetInput(context, node, kInputTopK); const int32 k = top_k->data.i32[0]; // The tensor can have more than 2 dimensions or even be a vector, the code // anyway calls the internal dimension as row; - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); const int32 row_size = input->dims->data[input->dims->size - 1]; int32 num_rows = 1; for (int i = 0; i < input->dims->size - 1; ++i) { @@ -215,7 +214,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { output_values->data.i64); break; default: - context->ReportError(context, "Type is currently not supported by TopK."); + context->ReportError(context, + "Type %d is currently not supported by TopK.", + output_values->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/transpose.cc b/tensorflow/contrib/lite/kernels/transpose.cc index d3c10a9bb7b07404ccd8cfe2636473a622b91787..800b0563d7ee6126d65005ff4ef61219db9eebb5 100644 
--- a/tensorflow/contrib/lite/kernels/transpose.cc +++ b/tensorflow/contrib/lite/kernels/transpose.cc @@ -37,8 +37,8 @@ struct TransposeContext { perm = GetInput(context, node, 1); output = GetOutput(context, node, 0); } - TfLiteTensor* input; - TfLiteTensor* perm; + const TfLiteTensor* input; + const TfLiteTensor* perm; TfLiteTensor* output; }; @@ -136,7 +136,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { break; default: context->ReportError(context, - "Type is currently not supported by Transpose."); + "Type %d is currently not supported by Transpose.", + op_context.input->type); return kTfLiteError; } #undef TF_LITE_TRANSPOSE diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc new file mode 100644 index 0000000000000000000000000000000000000000..e83b1ec9879d3c360203a52835d8486d0a9b81bb --- /dev/null +++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc @@ -0,0 +1,146 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/kernels/padding.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace transpose_conv { + +constexpr int kOutputShapeTensor = 0; +constexpr int kWeightsTensor = 1; +constexpr int kDataInputTensor = 2; +constexpr int kOutputTensor = 0; + +TfLiteStatus ResizeOutputShape(TfLiteContext* context, + const TfLiteTensor* output_shape, + TfLiteTensor* output) { + // Currently only support int32 for output shape. + if (output_shape->type != kTfLiteInt32) { + context->ReportError(context, "Output shape is %d, not int32.", + output_shape->type); + return kTfLiteError; + } + const int output_dimensions = NumElements(output_shape); + TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(output_dimensions); + for (int i = 0; i < output_dimensions; ++i) { + output_shape_array->data[i] = GetTensorData(output_shape)[i]; + } + + return context->ResizeTensor(context, output, output_shape_array); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* output_shape = + GetInput(context, node, kOutputShapeTensor); + const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor); + const TfLiteTensor* input = GetInput(context, node, kDataInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1); + 
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 4); + + // Currenlty only supports float32. + const TfLiteType data_type = input->type; + TF_LITE_ENSURE(context, data_type == kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, output->type, data_type); + TF_LITE_ENSURE_EQ(context, weights->type, data_type); + + // Ensure that weights and inputs have the same channel dimension. + // Note: TOCO will reorder weights in the following format: OHWI. + TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 3), + SizeOfDimension(weights, 3)); + + if (!IsConstantTensor(output_shape)) { + SetTensorToDynamic(output); + return kTfLiteOk; + } + return ResizeOutputShape(context, output_shape, output); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* output_shape = + GetInput(context, node, kOutputShapeTensor); + const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor); + const TfLiteTensor* input = GetInput(context, node, kDataInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + const auto* params = + reinterpret_cast(node->builtin_data); + + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, + ResizeOutputShape(context, output_shape, output)); + } + + // Get height and width of the output image. + const int width = SizeOfDimension(output, 2); + const int height = SizeOfDimension(output, 1); + const int filter_width = SizeOfDimension(weights, 1); + const int filter_height = SizeOfDimension(weights, 2); + + const int stride_width = params->stride_width; + const int stride_height = params->stride_height; + + const TfLitePaddingValues& padding_size = + ComputePaddingHeightWidth(stride_height, stride_width, 1, height, width, + filter_height, filter_width, params->padding); + + // Currently only support float32. 
+ switch (input->type) { + case kTfLiteFloat32: + optimized_ops::TransposeConv( + GetTensorData(input), GetTensorDims(input), + GetTensorData(weights), GetTensorDims(weights), stride_width, + stride_height, padding_size.width, padding_size.height, + GetTensorData(output), GetTensorDims(output)); + break; + default: + context->ReportError(context, "Type %d, not currently supported.", + input->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace transpose_conv + +TfLiteRegistration* Register_TRANSPOSE_CONV() { + static TfLiteRegistration r = {nullptr, nullptr, transpose_conv::Prepare, + transpose_conv::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..55df8971806ed0baae9f5bcaebd24fb8065ec300 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc @@ -0,0 +1,222 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class TransposeConvOpModel : public SingleOpModel { + public: + TransposeConvOpModel(std::initializer_list input_shape, + std::initializer_list filter_shape, Padding padding, + int stride_w, int stride_h) { + output_shape_ = AddInput(TensorType_INT32); + filter_ = AddInput(TensorType_FLOAT32); + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_TRANSPOSE_CONV, BuiltinOptions_TransposeConvOptions, + CreateTransposeConvOptions(builder_, padding, stride_w, stride_h) + .Union()); + BuildInterpreter({{4}, filter_shape, input_shape}); + } + + int output_shape() { return output_shape_; } + int filter() { return filter_; } + int input() { return input_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int output_shape_; + int filter_; + int input_; + int output_; +}; + +// Test case: +// output = tf.nn.conv2d_backprop_input( +// tf.constant([ 1, 4, 4, 1 ]), +// tf.constant(np.arange(1, 10), shape=[ 3, 3, 1, 1 ], dtype=tf.float32), +// tf.constant(np.arange(1, 17), shape=[ 1, 4, 4, 1 ], dtype=tf.float32), +// [1, 1, 1, 1 ], +// "SAME") +TEST(TransposeConvOpModelTest, SimpleTest) { + TransposeConvOpModel m({1, 4, 4, 1}, {1, 3, 3, 1}, Padding_SAME, 1, 1); + m.PopulateTensor(m.output_shape(), {1, 4, 4, 1}); + m.PopulateTensor(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9}); + m.PopulateTensor( + m.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({29, 62, 83, 75, 99, 
192, 237, 198, 207, 372, + 417, 330, 263, 446, 485, 365})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +// Test case: +// filter = tf.constant(np.arange(1, 19), +// shape=[ 3, 3, 1, 2 ], +// dtype=tf.float32) +// output = tf.nn.conv2d_backprop_input( +// tf.constant([ 1, 4, 4, 1 ]), +// filter, +// tf.constant(np.arange(1, 33), shape=[ 1, 4, 4, 2 ], dtype=tf.float32), +// [1, 1, 1, 1 ], +// "SAME") +// And filter value is derived by: +// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[18, 1]) +TEST(TransposeConvOpModelTest, TwoFiltersTest) { + TransposeConvOpModel m({1, 4, 4, 2}, {1, 3, 3, 2}, Padding_SAME, 1, 1); + m.PopulateTensor(m.output_shape(), {1, 4, 4, 1}); + m.PopulateTensor(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18}); + m.PopulateTensor( + m.input(), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({184, 412, 568, 528, 678, 1347, 1689, 1434, 1494, + 2715, 3057, 2442, 1968, 3352, 3652, 2760})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +// Test case: +// filter = tf.constant(np.arange(1, 19), +// shape=[ 3, 3, 1, 2 ], +// dtype=tf.float32) +// output = tf.nn.conv2d_backprop_input( +// tf.constant([ 1, 6, 6, 1 ]), +// filter, +// tf.constant(np.arange(1, 33), shape=[ 1, 4, 4, 2 ], dtype=tf.float32), +// [1, 1, 1, 1 ], +// "VALID") +// And filter value is derived by: +// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[1, 18]) +TEST(TransposeConvOpModelTest, PaddingValidTest) { + TransposeConvOpModel m({1, 4, 4, 2}, {1, 3, 3, 2}, Padding_VALID, 1, 1); + m.PopulateTensor(m.output_shape(), {1, 6, 6, 1}); + m.PopulateTensor(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18}); + m.PopulateTensor( + m.input(), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 
+ 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({5, 22, 59, 101, 114, 83, 52, 184, 412, + 568, 528, 344, 237, 678, 1347, 1689, 1434, 879, + 597, 1494, 2715, 3057, 2442, 1431, 856, 1968, 3352, + 3652, 2760, 1548, 689, 1534, 2543, 2729, 2010, 1103})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 6, 6, 1})); +} + +// Test case: +// filter = tf.constant(np.arange(1, 10), +// shape=[ 3, 3, 1, 1 ], +// dtype=tf.float32) +// output = tf.nn.conv2d_backprop_input( +// tf.constant([ 1, 5, 5, 1 ]), +// filter, +// tf.constant(np.arange(1, 5), shape=[ 1, 2, 2, 1 ], dtype=tf.float32), +// [1, 2, 2, 1 ], +// "VALID") +TEST(TransposeConvOpModelTest, StrideValidTest) { + TransposeConvOpModel m({1, 2, 2, 1}, {1, 3, 3, 1}, Padding_VALID, 2, 2); + m.PopulateTensor(m.output_shape(), {1, 5, 5, 1}); + m.PopulateTensor(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9}); + m.PopulateTensor(m.input(), {1, 2, 3, 4}); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1, 2, 5, 4, 6, 4, 5, 14, 10, 12, 10, 14, 36, + 24, 30, 12, 15, 34, 20, 24, 21, 24, 55, 32, 36})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 5, 5, 1})); +} + +// Test case: +// filter = tf.constant(np.arange(1, 19), +// shape=[ 3, 3, 2, 1 ], +// dtype=tf.float32) +// output = tf.nn.conv2d_backprop_input( +// tf.constant([ 1, 5, 5, 2 ]), +// filter, +// tf.constant(np.arange(1, 5), shape=[ 1, 2, 2, 1 ], dtype=tf.float32), +// [1, 2, 2, 1 ], +// "VALID") +TEST(TransposeConvOpModelTest, MultiChannelTest) { + TransposeConvOpModel m({1, 2, 2, 1}, {2, 3, 3, 1}, Padding_VALID, 2, 2); + m.PopulateTensor(m.output_shape(), {1, 5, 5, 2}); + m.PopulateTensor(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, + 8, 10, 12, 14, 16, 18}); + m.PopulateTensor(m.input(), {1, 2, 3, 4}); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 7, 10, 6, 8, 10, 12, 7, 8, 9, + 10, 25, 28, 18, 20, 22, 24, 16, 20, 
24, 28, 62, 72, + 42, 48, 54, 60, 21, 24, 27, 30, 61, 68, 36, 40, 44, + 48, 39, 42, 45, 48, 103, 110, 60, 64, 68, 72})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 5, 5, 2})); +} + +// Test case: +// filter = tf.constant(np.random.randint(1, 10, size=9), +// shape=[ 3, 3, 1, 1 ], +// dtype=tf.float32) +// output = tf.nn.conv2d_backprop_input( +// tf.constant([ 1, 3, 4, 1 ]), +// filter, +// tf.constant([323, 521], shape=[ 1, 1, 2, 1], dtype=tf.float32), +// [1, 3, 3, 1 ], +// "SAME") +// And filter value is derived by: +// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[-1]) +TEST(TransposeConvOpModelTest, AccuracyTest) { + TransposeConvOpModel m({1, 1, 2, 1}, {1, 3, 3, 1}, Padding_SAME, 3, 3); + m.PopulateTensor(m.output_shape(), {1, 3, 4, 1}); + m.PopulateTensor(m.filter(), {9, 5, 6, 9, 8, 5, 3, 1, 4}); + m.PopulateTensor(m.input(), {323, 521}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + {1615., 1938., 4689., 2605., 2584., 1615., + 4689., 4168., 323., 1292., 1563., 521.}))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 4, 1})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc index 5987bf68b5a73eaf39567bd65a0508839475505f..1c28123a24edd9886476bf8e9ea3ba4c692baa2b 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc @@ -92,7 +92,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TF_LITE_ENSURE(context, params->cell_clip >= 0); TF_LITE_ENSURE(context, params->proj_clip >= 0); - TfLiteTensor* input_to_input_weights = + const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(context, node, 
kInputToInputWeightsTensor); if (input_to_input_weights) { TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); @@ -100,19 +100,19 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); } - TfLiteTensor* input_to_forget_weights = + const TfLiteTensor* input_to_forget_weights = GetInput(context, node, kInputToForgetWeightsTensor); TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell); TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input); - TfLiteTensor* input_to_cell_weights = + const TfLiteTensor* input_to_cell_weights = GetInput(context, node, kInputToCellWeightsTensor); TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell); TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input); - TfLiteTensor* recurrent_to_input_weights = + const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); if (recurrent_to_input_weights) { TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2); @@ -122,7 +122,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, n_output); } - TfLiteTensor* recurrent_to_forget_weights = + const TfLiteTensor* recurrent_to_forget_weights = GetInput(context, node, kRecurrentToForgetWeightsTensor); TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0], @@ -130,7 +130,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1], n_output); - TfLiteTensor* recurrent_to_cell_weights = + const TfLiteTensor* recurrent_to_cell_weights = GetInput(context, node, kRecurrentToCellWeightsTensor); 
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell); @@ -146,21 +146,21 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, (recurrent_to_input_weights == nullptr)); TF_LITE_ENSURE(context, cifg_weights_all_or_none == true); - TfLiteTensor* cell_to_input_weights = + const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); if (cell_to_input_weights) { TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1); TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell); } - TfLiteTensor* cell_to_forget_weights = + const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); if (cell_to_forget_weights) { TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1); TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell); } - TfLiteTensor* cell_to_output_weights = + const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); if (cell_to_output_weights) { TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1); @@ -179,7 +179,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TF_LITE_ENSURE(context, peephole_weights_all_or_none == true); // Make sure the input gate bias is present only when not a CIFG-LSTM. 
- TfLiteTensor* input_gate_bias = + const TfLiteTensor* input_gate_bias = GetOptionalInputTensor(context, node, kInputGateBiasTensor); if (use_cifg) { TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); @@ -188,21 +188,21 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); } - TfLiteTensor* forget_gate_bias = + const TfLiteTensor* forget_gate_bias = GetInput(context, node, kForgetGateBiasTensor); TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell); - TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell); - TfLiteTensor* output_gate_bias = + const TfLiteTensor* output_gate_bias = GetInput(context, node, kOutputGateBiasTensor); TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell); - TfLiteTensor* projection_weights = + const TfLiteTensor* projection_weights = GetOptionalInputTensor(context, node, kProjectionWeightsTensor); if (projection_weights) { TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2); @@ -210,7 +210,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell); } - TfLiteTensor* projection_bias = + const TfLiteTensor* projection_bias = GetOptionalInputTensor(context, node, kProjectionBiasTensor); if (projection_bias) { TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); @@ -241,19 +241,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Inferring batch size, number of outputs and sequence length and // number of cells from the input tensors. 
- TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TF_LITE_ENSURE(context, input->dims->size > 1); const int max_time = input->dims->data[0]; const int n_batch = input->dims->data[1]; const int n_input = input->dims->data[2]; - TfLiteTensor* input_to_output_weights = + const TfLiteTensor* input_to_output_weights = GetInput(context, node, kInputToOutputWeightsTensor); const int n_cell = input_to_output_weights->dims->data[0]; TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input); - TfLiteTensor* recurrent_to_output_weights = + const TfLiteTensor* recurrent_to_output_weights = GetInput(context, node, kRecurrentToOutputWeightsTensor); TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0], @@ -300,7 +300,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_state->allocation_type = kTfLiteArenaRwPersistent; cell_state->allocation_type = kTfLiteArenaRwPersistent; - TfLiteTensor* input_to_input_weights = + const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); const bool use_cifg = (input_to_input_weights == nullptr); if (use_cifg) { @@ -324,44 +324,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // The LSTM Op engine. 
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); - TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); - TfLiteTensor* input_to_input_weights = + const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); - TfLiteTensor* input_to_forget_weights = + const TfLiteTensor* input_to_forget_weights = GetInput(context, node, kInputToForgetWeightsTensor); - TfLiteTensor* input_to_cell_weights = + const TfLiteTensor* input_to_cell_weights = GetInput(context, node, kInputToCellWeightsTensor); - TfLiteTensor* input_to_output_weights = + const TfLiteTensor* input_to_output_weights = GetInput(context, node, kInputToOutputWeightsTensor); - TfLiteTensor* recurrent_to_input_weights = + const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); - TfLiteTensor* recurrent_to_forget_weights = + const TfLiteTensor* recurrent_to_forget_weights = GetInput(context, node, kRecurrentToForgetWeightsTensor); - TfLiteTensor* recurrent_to_cell_weights = + const TfLiteTensor* recurrent_to_cell_weights = GetInput(context, node, kRecurrentToCellWeightsTensor); - TfLiteTensor* recurrent_to_output_weights = + const TfLiteTensor* recurrent_to_output_weights = GetInput(context, node, kRecurrentToOutputWeightsTensor); - TfLiteTensor* cell_to_input_weights = + const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); - TfLiteTensor* cell_to_forget_weights = + const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); - TfLiteTensor* cell_to_output_weights = + const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); - TfLiteTensor* input_gate_bias = + const TfLiteTensor* input_gate_bias = 
GetOptionalInputTensor(context, node, kInputGateBiasTensor); - TfLiteTensor* forget_gate_bias = + const TfLiteTensor* forget_gate_bias = GetInput(context, node, kForgetGateBiasTensor); - TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); - TfLiteTensor* output_gate_bias = + const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + const TfLiteTensor* output_gate_bias = GetInput(context, node, kOutputGateBiasTensor); - TfLiteTensor* projection_weights = + const TfLiteTensor* projection_weights = GetOptionalInputTensor(context, node, kProjectionWeightsTensor); - TfLiteTensor* projection_bias = + const TfLiteTensor* projection_bias = GetOptionalInputTensor(context, node, kProjectionBiasTensor); TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc index ac00c37b67dcbe77023a2495a698967ca555b1d5..164a0cbd08d6ce82a413f12ba6b1703087a30aba 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { @@ -38,17 +39,26 @@ constexpr int kBiasTensor = 3; constexpr int kHiddenStateTensor = 0; constexpr int kOutputTensor = 1; +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* scratch_tensor_index = new int; + context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index); + return scratch_tensor_index; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast<int*>(buffer); +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Check we have all the inputs and outputs we need. TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); - TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; - TfLiteTensor* input_weights = - &context->tensors[node->inputs->data[kWeightsTensor]]; - TfLiteTensor* recurrent_weights = - &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; - TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor); + const TfLiteTensor* recurrent_weights = + GetInput(context, node, kRecurrentWeightsTensor); + const TfLiteTensor* bias = GetInput(context, node, kBiasTensor); // Check all the parameters of tensor match within themselves and match the // input configuration. 
@@ -63,10 +73,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]); TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]); TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, input_weights->type, recurrent_weights->type); - TfLiteTensor* hidden_state = - &context->tensors[node->outputs->data[kHiddenStateTensor]]; - TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; + TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // Resize state. TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2); @@ -86,22 +97,54 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, output_size_array)); + // Allocate temporary tensors to store quantized values of input and + // hidden_state tensors. 
+ if (input->type == kTfLiteFloat32 && input_weights->type == kTfLiteUInt8) { + int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data); + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(3); + node->temporaries->data[0] = *scratch_tensor_index; + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); + input_quantized->type = kTfLiteUInt8; + input_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { + TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, + input_quantized_size)); + } + node->temporaries->data[1] = *scratch_tensor_index + 1; + TfLiteTensor* hidden_state_quantized = + GetTemporary(context, node, /*index=*/1); + hidden_state_quantized->type = kTfLiteUInt8; + hidden_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(hidden_state_quantized->dims, + hidden_state->dims)) { + TfLiteIntArray* hidden_state_quantized_size = + TfLiteIntArrayCopy(hidden_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, hidden_state_quantized, + hidden_state_quantized_size)); + } + node->temporaries->data[2] = *scratch_tensor_index + 2; + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = batch_size; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } + } return kTfLiteOk; } -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data); - - TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; - 
TfLiteTensor* input_weights = - &context->tensors[node->inputs->data[kWeightsTensor]]; - TfLiteTensor* recurrent_weights = - &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; - TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; - TfLiteTensor* hidden_state = - &context->tensors[node->outputs->data[kHiddenStateTensor]]; - TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; - +TfLiteStatus EvalFloat(const TfLiteTensor* input, + const TfLiteTensor* input_weights, + const TfLiteTensor* recurrent_weights, + const TfLiteTensor* bias, + const TfLiteSequenceRNNParams* params, + TfLiteTensor* hidden_state, TfLiteTensor* output) { // Initialize the pointer bias. const float* bias_ptr = bias->data.f; @@ -120,7 +163,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (time_major) { // Initialize the pointer to hidden state. float* hidden_state_ptr_batch = hidden_state->data.f; - // Unroll the sequence and use batch batch operations for efficiency. + // Unroll the sequence and use batch operations for efficiency. for (int s = 0; s < max_time; s++) { // Initialize the pointer to input and output. const float* input_ptr_batch = @@ -154,12 +197,116 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus EvalHybrid( + const TfLiteTensor* input, const TfLiteTensor* input_weights, + const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias, + const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch, + TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors, + TfLiteTensor* hidden_state, TfLiteTensor* output) { + const bool time_major = params->time_major; + const int batch_size = + (time_major) ? input->dims->data[1] : input->dims->data[0]; + const int max_time = + (time_major) ? 
input->dims->data[0] : input->dims->data[1]; + const int num_units = input_weights->dims->data[0]; + const int input_size = input->dims->data[2]; + + // Initialize the pointer bias. + const float* bias_ptr = bias->data.f; + // Initialize input_weights and recurrent_weights. + const int8_t* input_weights_ptr = + reinterpret_cast<const int8_t*>(input_weights->data.uint8); + const int8_t* recurrent_weights_ptr = + reinterpret_cast<const int8_t*>(recurrent_weights->data.uint8); + // Get the scale of the quantized weights. + float input_weights_scale = input_weights->params.scale; + float recurrent_weights_scale = recurrent_weights->params.scale; + // Initialize temporary storage for quantized values. + int8_t* quantized_input_ptr = + reinterpret_cast<int8_t*>(input_scratch->data.uint8); + int8_t* quantized_hidden_state_ptr = + reinterpret_cast<int8_t*>(hidden_state_scratch->data.uint8); + float* scaling_factors_ptr = scaling_factors->data.f; + + if (time_major) { + // Initialize the pointer to hidden state. + float* hidden_state_ptr_batch = hidden_state->data.f; + // Unroll the sequence and use batch operations for efficiency. + for (int s = 0; s < max_time; s++) { + // Initialize the pointer to input and output. + const float* input_ptr_batch = + input->data.f + s * input_size * batch_size; + float* output_ptr_batch = output->data.f + s * num_units * batch_size; + + kernel_utils::RnnBatchStep( + input_ptr_batch, input_weights_ptr, input_weights_scale, + recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, + num_units, batch_size, params->activation, quantized_input_ptr, + quantized_hidden_state_ptr, scaling_factors_ptr, + hidden_state_ptr_batch, output_ptr_batch); + } + } else { + // For each batch + for (int b = 0; b < batch_size; b++) { + // Initialize the pointer to hidden state. + float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units; + for (int s = 0; s < max_time; s++) { + // Initialize the pointer to input and output. 
+ const float* input_ptr_batch = + input->data.f + b * input_size * max_time + s * input_size; + float* output_ptr_batch = + output->data.f + b * num_units * max_time + s * num_units; + + kernel_utils::RnnBatchStep( + input_ptr_batch, input_weights_ptr, input_weights_scale, + recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, + input_size, num_units, /*batch_size=*/1, params->activation, + quantized_input_ptr, quantized_hidden_state_ptr, + scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch); + } + } + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor); + const TfLiteTensor* recurrent_weights = + GetInput(context, node, kRecurrentWeightsTensor); + const TfLiteTensor* bias = GetInput(context, node, kBiasTensor); + TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + switch (input_weights->type) { + case kTfLiteFloat32: + return EvalFloat(input, input_weights, recurrent_weights, bias, params, + hidden_state, output); + case kTfLiteUInt8: { + // TODO(mirkov): implement eval with quantized inputs as well.
+ TfLiteTensor* input_quantized = GetTemporary(context, node, 0); + TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1); + TfLiteTensor* scaling_factors = GetTemporary(context, node, 2); + return EvalHybrid(input, input_weights, recurrent_weights, bias, params, + input_quantized, hidden_state_quantized, + scaling_factors, hidden_state, output); + } + default: + context->ReportError(context, "Type %d not currently supported.", + input_weights->type); + return kTfLiteError; + } + return kTfLiteOk; +} + } // namespace unidirectional_sequence_rnn TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_RNN() { - static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, - unidirectional_sequence_rnn::Prepare, - unidirectional_sequence_rnn::Eval}; + static TfLiteRegistration r = { + unidirectional_sequence_rnn::Init, unidirectional_sequence_rnn::Free, + unidirectional_sequence_rnn::Prepare, unidirectional_sequence_rnn::Eval}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc index 7e32969763b59620dc3534708f965750680002d2..0adab837b07a6d3bd5d7edd267916cd8e1bb75b2 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc @@ -122,17 +122,66 @@ static float rnn_golden_output[] = { 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, 0.628881, 3.58099, 1.49974, 0}; +static std::initializer_list rnn_weights = { + 0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, + 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, + 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, + -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, + -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, + -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, + -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, 
-0.423241, + 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, + 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, + 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, + -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, + 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, + -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, + -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, + 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, + 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, + 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, + -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, + 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, + 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, + -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, + 0.277308, 0.415818}; + +static std::initializer_list rnn_recurrent_weights = { + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1}; + +static std::initializer_list rnn_bias = { + 0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, -0.23566568, + -0.389184, 0.47481549, 
-0.4791103, 0.29931796, 0.10463274, 0.83918178, + 0.37197268, 0.61957061, 0.3956964, -0.37609905}; + class UnidirectionalRNNOpModel : public SingleOpModel { public: - UnidirectionalRNNOpModel(int batches, int sequence_len, int units, int size, - bool time_major) + UnidirectionalRNNOpModel( + int batches, int sequence_len, int units, int size, bool time_major, + const TensorType& weights = TensorType_FLOAT32, + const TensorType& recurrent_weights = TensorType_FLOAT32) : batches_(batches), sequence_len_(sequence_len), units_(units), input_size_(size) { input_ = AddInput(TensorType_FLOAT32); - weights_ = AddInput(TensorType_FLOAT32); - recurrent_weights_ = AddInput(TensorType_FLOAT32); + weights_ = AddInput(weights); + recurrent_weights_ = AddInput(recurrent_weights); bias_ = AddInput(TensorType_FLOAT32); hidden_state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); @@ -187,7 +236,7 @@ class UnidirectionalRNNOpModel : public SingleOpModel { int num_batches() { return batches_; } int sequence_len() { return sequence_len_; } - private: + protected: int input_; int weights_; int recurrent_weights_; @@ -201,58 +250,31 @@ class UnidirectionalRNNOpModel : public SingleOpModel { int input_size_; }; -// TODO(mirkov): add another test which directly compares to TF once TOCO -// supports the conversion from dynamic_rnn with BasicRNNCell. -TEST(FullyConnectedOpTest, BlackBoxTest) { +// The hybrid model has quantized weights and recurrent_weights. 
+class HybridUnidirectionalRNNOpModel : public UnidirectionalRNNOpModel { + public: + HybridUnidirectionalRNNOpModel(int batches, int sequence_len, int units, + int size, bool time_major) + : UnidirectionalRNNOpModel(batches, sequence_len, units, size, time_major, + TensorType_UINT8, TensorType_UINT8) {} + + void SetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(weights_, f); + } + + void SetRecurrentWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_weights_, f); + } +}; + +TEST(UnidirectionalRNNOpTest, BlackBoxTest) { UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*units=*/16, /*size=*/8, /*time_major=*/false); - rnn.SetWeights( - {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, - 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, - 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, - -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, - -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, - -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, - -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, - 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, - 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, - 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, - -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, - 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, - -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, - -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, - 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, - 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, - 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, - -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, - 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, - 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 
0.972934, - -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, - 0.277308, 0.415818}); - - rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, - -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796, - 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, - -0.37609905}); - - rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1}); - + rnn.SetWeights(rnn_weights); + rnn.SetBias(rnn_bias); + rnn.SetRecurrentWeights(rnn_recurrent_weights); rnn.ResetHiddenState(); + const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); float* batch_start = rnn_input; float* batch_end = batch_start + input_sequence_size; @@ -270,56 +292,42 @@ TEST(FullyConnectedOpTest, BlackBoxTest) { EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); } -TEST(FullyConnectedOpTest, TimeMajorBlackBoxTest) { - UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, - /*units=*/16, /*size=*/8, /*time_major=*/true); - rnn.SetWeights( - {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, - 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, - 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, - 
-0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, - -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, - -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, - -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, - 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, - 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, - 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, - -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, - 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, - -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, - -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, - 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, - 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, - 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, - -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, - 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, - 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, - -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, - 0.277308, 0.415818}); - - rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, - -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796, - 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, - -0.37609905}); - - rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1}); +TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTest) { + HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*units=*/16, /*size=*/8, + /*time_major=*/false); + rnn.SetWeights(rnn_weights); + rnn.SetBias(rnn_bias); + rnn.SetRecurrentWeights(rnn_recurrent_weights); + rnn.ResetHiddenState(); + + const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); + float* batch_start = rnn_input; + float* batch_end = batch_start + input_sequence_size; + rnn.SetInput(0, batch_start, batch_end); + rnn.SetInput(input_sequence_size, batch_start, batch_end); + + rnn.Invoke(); + + float* golden_start = rnn_golden_output; + float* golden_end = golden_start + rnn.num_units() * rnn.sequence_len(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear( + expected, /*max_abs_error=*/0.013))); +} +TEST(UnidirectionalRNNOpTest, TimeMajorBlackBoxTest) { + UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*units=*/16, /*size=*/8, + /*time_major=*/true); + rnn.SetWeights(rnn_weights); + rnn.SetBias(rnn_bias); + rnn.SetRecurrentWeights(rnn_recurrent_weights); rnn.ResetHiddenState(); + for (int i = 0; i < rnn.sequence_len(); i++) { float* batch_start = rnn_input + i * rnn.input_size(); float* batch_end = batch_start + rnn.input_size(); @@ -341,6 +349,37 @@ TEST(FullyConnectedOpTest, TimeMajorBlackBoxTest) { EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); } +TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTest) { + HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, 
/*sequence_len=*/16, + /*units=*/16, /*size=*/8, + /*time_major=*/true); + rnn.SetWeights(rnn_weights); + rnn.SetBias(rnn_bias); + rnn.SetRecurrentWeights(rnn_recurrent_weights); + rnn.ResetHiddenState(); + + for (int i = 0; i < rnn.sequence_len(); i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + // The two batches are identical. + rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end); + rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end); + } + + rnn.Invoke(); + + std::vector expected; + for (int i = 0; i < rnn.sequence_len(); i++) { + float* golden_batch_start = rnn_golden_output + i * rnn.num_units(); + float* golden_batch_end = golden_batch_start + rnn.num_units(); + expected.insert(expected.end(), golden_batch_start, golden_batch_end); + expected.insert(expected.end(), golden_batch_start, golden_batch_end); + } + + EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear( + expected, /*max_abs_error=*/0.013))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index e89036ce730dd7cc507559ad908fb01513aba708..039f32b38eb29068b223dd63355c66295301beba 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -184,8 +184,10 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { TfLiteStatus status = kTfLiteOk; auto opcodes = model_->operator_codes(); for (const OperatorCode* opcode : *opcodes) { - TfLiteRegistration* registration = nullptr; + const TfLiteRegistration* registration = nullptr; auto builtin_code = opcode->builtin_code(); + int version = opcode->version(); + if (builtin_code > BuiltinOperator_MAX || builtin_code < BuiltinOperator_MIN) { error_reporter_->Report( @@ -194,8 +196,7 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { builtin_code); status = kTfLiteError; } else if (builtin_code != 
BuiltinOperator_CUSTOM) { - flatbuffer_op_index_to_registration_types_.push_back(builtin_code); - registration = op_resolver_.FindOp(builtin_code); + registration = op_resolver_.FindOp(builtin_code, version); if (registration == nullptr) { error_reporter_->Report("Didn't find op for builtin opcode '%s'\n", EnumNameBuiltinOperator(builtin_code)); @@ -207,11 +208,13 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { status = kTfLiteError; } else { const char* name = opcode->custom_code()->c_str(); - registration = op_resolver_.FindOp(name); + registration = op_resolver_.FindOp(name, version); flatbuffer_op_index_to_registration_types_.push_back( BuiltinOperator_CUSTOM); if (registration == nullptr) { - error_reporter_->Report("Didn't find custom op for name '%s'\n", name); + error_reporter_->Report( + "Didn't find custom op for name '%s' with version %d\n", name, + version); status = kTfLiteError; } } @@ -319,12 +322,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = nullptr; switch (op_type) { - case BuiltinOperator_CALL: - // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are - // ok for now, since there is no call implementation either. 
- break; - case BuiltinOperator_CUSTOM: - break; case BuiltinOperator_CONV_2D: { TfLiteConvParams* params = MallocPOD<TfLiteConvParams>(); if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) { @@ -333,26 +330,13 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, params->stride_height = conv_params->stride_h(); params->activation = parse_activation(conv_params->fused_activation_function()); + params->dilation_width_factor = conv_params->dilation_w_factor(); params->dilation_height_factor = conv_params->dilation_h_factor(); } *builtin_data = reinterpret_cast<void*>(params); break; } - case BuiltinOperator_TANH: - case BuiltinOperator_LOGISTIC: - case BuiltinOperator_RELU: - case BuiltinOperator_RELU_N1_TO_1: - case BuiltinOperator_RELU6: - case BuiltinOperator_CONCAT_EMBEDDINGS: - case BuiltinOperator_EXP: - case BuiltinOperator_TOPK_V2: - case BuiltinOperator_LOG_SOFTMAX: - case BuiltinOperator_DEQUANTIZE: - case BuiltinOperator_PRELU: - case BuiltinOperator_FLOOR: - case BuiltinOperator_NEG: - break; case BuiltinOperator_CAST: { TfLiteCastParams* params = MallocPOD<TfLiteCastParams>(); if (auto* schema_params = op->builtin_options_as_CastOptions()) { @@ -440,9 +424,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast<void*>(params); break; } - case BuiltinOperator_EMBEDDING_LOOKUP: - // no-op.
- break; case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { TfLiteEmbeddingLookupSparseParams* params = MallocPOD(); @@ -553,6 +534,14 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, parse_activation(lstm_params->fused_activation_function()); params->cell_clip = lstm_params->cell_clip(); params->proj_clip = lstm_params->proj_clip(); + switch (lstm_params->kernel_type()) { + case LSTMKernelType_FULL: + params->kernel_type = kTfLiteLSTMFullKernel; + break; + case LSTMKernelType_BASIC: + params->kernel_type = kTfLiteLSTMBasicKernel; + break; + } } *builtin_data = reinterpret_cast(params); break; @@ -566,12 +555,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_PAD: { - break; - } - case BuiltinOperator_PADV2: { - break; - } case BuiltinOperator_RESHAPE: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { @@ -611,15 +594,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_SPACE_TO_BATCH_ND: { - break; - } - case BuiltinOperator_BATCH_TO_SPACE_ND: { - break; - } - case BuiltinOperator_TRANSPOSE: { - break; - } case BuiltinOperator_MEAN: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_MeanOptions()) { @@ -659,10 +633,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_MAXIMUM: - case BuiltinOperator_MINIMUM: { - break; - } case BuiltinOperator_ARG_MAX: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) { @@ -672,11 +642,26 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_GREATER: - case BuiltinOperator_GREATER_EQUAL: - case BuiltinOperator_LESS: 
- case BuiltinOperator_LESS_EQUAL: - case BuiltinOperator_SELECT: { + case BuiltinOperator_TRANSPOSE_CONV: { + TfLiteTransposeConvParams* params = + MallocPOD<TfLiteTransposeConvParams>(); + if (auto* transpose_conv_params = + op->builtin_options_as_TransposeConvOptions()) { + params->padding = parse_padding(transpose_conv_params->padding()); + params->stride_width = transpose_conv_params->stride_w(); + params->stride_height = transpose_conv_params->stride_h(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_SPARSE_TO_DENSE: { + TfLiteSparseToDenseParams* params = + MallocPOD<TfLiteSparseToDenseParams>(); + if (auto* sparse_to_dense_params = + op->builtin_options_as_SparseToDenseOptions()) { + params->validate_indices = sparse_to_dense_params->validate_indices(); + } + *builtin_data = reinterpret_cast<void*>(params); break; } case BuiltinOperator_DELEGATE: { @@ -684,6 +669,46 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, error_reporter->Report("DELEGATE op shouldn't exist in model."); return kTfLiteError; } + + // Below are the ops with no builtin_data structure. + case BuiltinOperator_BATCH_TO_SPACE_ND: + // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are + // ok for now, since there is no call implementation either.
+ case BuiltinOperator_CALL: + case BuiltinOperator_CONCAT_EMBEDDINGS: + case BuiltinOperator_CUSTOM: + case BuiltinOperator_DEQUANTIZE: + case BuiltinOperator_EMBEDDING_LOOKUP: + case BuiltinOperator_EQUAL: + case BuiltinOperator_EXP: + case BuiltinOperator_EXPAND_DIMS: + case BuiltinOperator_FLOOR: + case BuiltinOperator_GREATER: + case BuiltinOperator_GREATER_EQUAL: + case BuiltinOperator_LESS: + case BuiltinOperator_LESS_EQUAL: + case BuiltinOperator_LOG: + case BuiltinOperator_LOGISTIC: + case BuiltinOperator_LOG_SOFTMAX: + case BuiltinOperator_MAXIMUM: + case BuiltinOperator_MINIMUM: + case BuiltinOperator_NEG: + case BuiltinOperator_NOT_EQUAL: + case BuiltinOperator_PAD: + case BuiltinOperator_PADV2: + case BuiltinOperator_PRELU: + case BuiltinOperator_RELU: + case BuiltinOperator_RELU6: + case BuiltinOperator_RELU_N1_TO_1: + case BuiltinOperator_SELECT: + case BuiltinOperator_SIN: + case BuiltinOperator_SLICE: + case BuiltinOperator_SPACE_TO_BATCH_ND: + case BuiltinOperator_TANH: + case BuiltinOperator_TILE: + case BuiltinOperator_TOPK_V2: + case BuiltinOperator_TRANSPOSE: + break; } return kTfLiteOk; } @@ -703,27 +728,30 @@ TfLiteStatus InterpreterBuilder::ParseNodes( status = kTfLiteError; continue; } - const TfLiteRegistration* reg = + + const TfLiteRegistration* registration = flatbuffer_op_index_to_registration_[op->opcode_index()]; - if (reg == nullptr) { + if (registration == nullptr) { error_reporter_->Report("Skipping op for opcode_index %d\n", index); status = kTfLiteError; continue; } - auto op_type = - flatbuffer_op_index_to_registration_types_[op->opcode_index()]; + BuiltinOperator op_type = + static_cast<BuiltinOperator>(registration->builtin_code); + if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) { error_reporter_->Report( "Found builtin operator %s with custom options.\n", EnumNameBuiltinOperator(op_type)); } + if (op->custom_options()) { interpreter->AddNodeWithParameters( FlatBufferIntArrayToVector(op->inputs()),
FlatBufferIntArrayToVector(op->outputs()), reinterpret_cast(op->custom_options()->data()), - op->custom_options()->size(), nullptr, reg); + op->custom_options()->size(), nullptr, registration); } else { void* builtin_data = nullptr; TF_LITE_ENSURE_STATUS( @@ -731,7 +759,7 @@ TfLiteStatus InterpreterBuilder::ParseNodes( interpreter->AddNodeWithParameters( FlatBufferIntArrayToVector(op->inputs()), FlatBufferIntArrayToVector(op->outputs()), nullptr, 0, builtin_data, - reg); + registration); } } diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h index 5a55b031a8c28085e02782608eb820a3cfe78dde..3946b490417104f620ecb55bb22d4ef99fd33bb7 100644 --- a/tensorflow/contrib/lite/model.h +++ b/tensorflow/contrib/lite/model.h @@ -37,6 +37,7 @@ limitations under the License. #include #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/op_resolver.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" namespace tflite { @@ -131,18 +132,6 @@ class FlatBufferModel { Allocation* allocation_ = nullptr; }; -// Abstract interface that returns TfLiteRegistrations given op codes or custom -// op names. This is the mechanism that ops being referenced in the flatbuffer -// model are mapped to executable function pointers (TfLiteRegistrations). -class OpResolver { - public: - // Finds the op registration for a builtin operator by enum code. - virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0; - // Finds the op registration of a custom operator by op name. - virtual TfLiteRegistration* FindOp(const char* op) const = 0; - virtual ~OpResolver() {} -}; - // Build an interpreter capable of interpreting `model`. 
// // model: a scoped model whose lifetime must be at least as long as @@ -187,7 +176,7 @@ class InterpreterBuilder { const OpResolver& op_resolver_; ErrorReporter* error_reporter_; - std::vector flatbuffer_op_index_to_registration_; + std::vector flatbuffer_op_index_to_registration_; std::vector flatbuffer_op_index_to_registration_types_; const Allocation* allocation_ = nullptr; }; diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc index ae6c1ece18963f11f48a6f07bea4065ce39687e0..15bae21a411c1241cf71ab4d3f0e0289eaac8ef3 100644 --- a/tensorflow/contrib/lite/model_test.cc +++ b/tensorflow/contrib/lite/model_test.cc @@ -55,11 +55,12 @@ class TrivialResolver : public OpResolver { explicit TrivialResolver(TfLiteRegistration* constant_return = nullptr) : constant_return_(constant_return) {} // Find the op registration of a custom operator by op name. - TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override { + const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, + int version) const override { return constant_return_; } // Find the op registration of a custom operator by op name. 
- TfLiteRegistration* FindOp(const char* op) const override { + const TfLiteRegistration* FindOp(const char* op, int version) const override { return constant_return_; } diff --git a/tensorflow/contrib/lite/models/smartreply/BUILD b/tensorflow/contrib/lite/models/smartreply/BUILD index a82d1f2eb673b9b7211581f5a9f9febc140d4d1e..8b5fa240ac31d9ee61879c42aee3c5d449ae60db 100644 --- a/tensorflow/contrib/lite/models/smartreply/BUILD +++ b/tensorflow/contrib/lite/models/smartreply/BUILD @@ -22,7 +22,6 @@ cc_library( "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/tools:mutable_op_resolver", "@com_google_absl//absl/strings", "@com_googlesource_code_re2//:re2", "@farmhash_archive//:farmhash", @@ -39,7 +38,6 @@ cc_library( "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/tools:mutable_op_resolver", "@com_google_absl//absl/strings", "@com_googlesource_code_re2//:re2", ], diff --git a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc index f97a6486d6c11cf0184622f515fe5b1e096c6257..29c8ad2286d705ea60fcd258e7283f6e1c3b70b8 100644 --- a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc +++ b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc @@ -61,7 +61,7 @@ bool IsValidNgram(const tflite::StringRef& strref) { TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArray* outputSize1 = TfLiteIntArrayCreate(1); TfLiteIntArray* outputSize2 = TfLiteIntArrayCreate(1); - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); int dim = input->dims->data[0]; if (dim == 0) { // TFLite non-string output should have size greater than 0. 
@@ -76,7 +76,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); int num_strings = tflite::GetStringCount(input); TfLiteTensor* label = GetOutput(context, node, 0); TfLiteTensor* weight = GetOutput(context, node, 1); diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.cc b/tensorflow/contrib/lite/models/smartreply/predictor.cc index 6da5cc8eecc0920850f666b0992c4d9598c55b6c..5d6c47dce8d90192d35a3a51fe6d0beb11f3b23f 100644 --- a/tensorflow/contrib/lite/models/smartreply/predictor.cc +++ b/tensorflow/contrib/lite/models/smartreply/predictor.cc @@ -20,8 +20,8 @@ limitations under the License. #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/op_resolver.h" #include "tensorflow/contrib/lite/string_util.h" -#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); @@ -104,11 +104,11 @@ void GetSegmentPredictions( }); // Add backoff response. 
- for (const string& backoff : config.backoff_responses) { + for (const auto& backoff : config.backoff_responses) { if (predictor_responses->size() >= config.num_response) { break; } - predictor_responses->push_back({backoff, config.backoff_confidence}); + predictor_responses->emplace_back(backoff, config.backoff_confidence); } } diff --git a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc index e6c8d966f1aff5a867f9469f8fcdec526df84763..c7e08814fdf502f1ecfea60af3385fc7aa6055fa 100644 --- a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc +++ b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc @@ -35,8 +35,8 @@ const char kModelName[] = "smartreply_ondevice_model.bin"; const char kSamples[] = "smartreply_samples.tsv"; string TestDataPath() { - return string(StrCat(tensorflow::testing::TensorFlowSrcRoot(), "/", - "contrib/lite/models/testdata/")); + return string(absl::StrCat(tensorflow::testing::TensorFlowSrcRoot(), "/", + "contrib/lite/models/testdata/")); } MATCHER_P(IncludeAnyResponesIn, expected_response, "contains the response") { @@ -55,7 +55,7 @@ class PredictorTest : public ::testing::Test { protected: PredictorTest() { model_ = tflite::FlatBufferModel::BuildFromFile( - StrCat(TestDataPath(), "/", kModelName).c_str()); + absl::StrCat(TestDataPath(), "/", kModelName).c_str()); CHECK(model_); } ~PredictorTest() override {} @@ -121,7 +121,7 @@ TEST_F(PredictorTest, BatchTest) { int total_triggers = 0; string line; - std::ifstream fin(StrCat(TestDataPath(), "/", kSamples)); + std::ifstream fin(absl::StrCat(TestDataPath(), "/", kSamples)); while (std::getline(fin, line)) { const std::vector fields = absl::StrSplit(line, '\t'); if (fields.empty()) { diff --git a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h index 4a648e42837fbf6b7326c315be202ae0a80a47ca..becd1f615f04a806cba9c494323285c004ec41df 100644 --- 
a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h +++ b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h @@ -65,7 +65,8 @@ inline bool NNAPIExists() { return nnapi_is_available; } -// nn api types +// NN api types based on NNAPI header file +// https://developer.android.com/ndk/reference/group/neural-networks /** * Operand types. @@ -77,31 +78,11 @@ inline bool NNAPIExists() { * ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, and ANEURALNETWORKS_INT32. */ enum { - /** The following entries are used to declare scalars. */ - - /** A 32 bit floating point scalar value. */ ANEURALNETWORKS_FLOAT32 = 0, - /** A signed 32 bit integer scalar value. */ ANEURALNETWORKS_INT32 = 1, - /** An unsigned 32 bit integer scalar value. */ ANEURALNETWORKS_UINT32 = 2, - - /** The following entries are used to declare tensors. */ - - /** A tensor of 32 bit floating point values. */ ANEURALNETWORKS_TENSOR_FLOAT32 = 3, - /** A tensor of 32 bit integer values. */ ANEURALNETWORKS_TENSOR_INT32 = 4, - /** A tensor of 8 bit integers that represent real numbers. - * - * Attached to this tensor are two numbers that can be used to convert - * the 8 bit integer to the real value and vice versa. These two numbers are: - * - scale: a 32 bit floating point value - * - zero_value: an 32 bit integer - * - * The formula is: - * real_value = (integer_value - zero_value) * scale. - */ ANEURALNETWORKS_TENSOR_QUANT8_ASYMM = 5, }; @@ -111,968 +92,44 @@ enum { * The type of operations that can be added to a model. */ enum { - /** Adds two tensors, element-wise. - * - * Takes two input tensors of identical type and compatible dimensions. The - * output is the sum of both input tensors, optionally modified by an - * activation function. - * - * Two dimensions are compatible when: - * 1. they are equal, or - * 2. one of them is 1 - * - * The size of the output is the maximum size along each dimension of the - * input operands. It starts with the trailing dimensions, and works its way - * forward. 
- * - * Example: - * - * input1.dimension = {4, 1, 2} - * input2.dimension = {5, 4, 3, 1} - * output.dimension = {5, 4, 3, 2} - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * - * Supported tensor rank: up to 4 - * - * Inputs: - * * 0: A tensor. - * * 1: A tensor of the same type, and compatible dimensions as input0. - * * 2: An INT32 value, and has to be one of the {@link FuseCode} values. - * Specifies the activation to invoke on the result of each addition. - * - * Outputs: - * * 0: The sum, a tensor of the same type as input0. - */ ANEURALNETWORKS_ADD = 0, - /** Performs a 2-D average pooling operation. - * - * The output dimensions are functions of the filter dimensions, stride, and - * padding. - * - * The values in the output tensor are computed as: - * - * output[batch, row, col, channel] = - * sum_{i, j}(input[batch, row + i, col + j, channel]) / sum(1) - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} - * - * Supported tensor rank: 4, with "NHWC" data layout. - * - * Inputs: - * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the - * input. - * * 1: An INT32 value, specifying the padding on the left, in the ‘width’ - * dimension. - * * 2: An INT32 value, specifying the padding on the right,in the ‘width’ - * dimension. - * * 3: An INT32 value, specifying the padding on the top, in the ‘height’ - * dimension. - * * 4: An INT32 value, specifying the padding on the bottom, in the ‘height’ - * dimension. - * * 5: An INT32 value, specifying the output stride in the ‘width’ dimension. - * * 6: An INT32 value, specifying the output stride in the ‘height’ - * dimension. - * * 7: An INT32 value, specifying the filter width. - * * 8: An INT32 value, specifying the filter height. - * * 9: An INT32 value, and has to be one of the {@link FuseCode} values. - * Specifies the activation to invoke on the result of each addition. 
- * - * Outputs: - * * 0: The output 4-D tensor, of shape [batches, out_height, out_width, - * depth]. - */ ANEURALNETWORKS_AVERAGE_POOL_2D = 1, - /** Concatenates the input tensors along the given dimension. - * - * The input tensors must have identical type and the same dimensions except - * the dimension along the concatenation axis. - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} - * - * Supported tensor rank: up to 4 - * - * Inputs: - * 0 ~ n: The list on n input tensors, of shape [D0, D1, ..., Daxis(i), ..., - * Dm] n+1: An INT32 value, specifying the concatenation axis. n+2: An INT32 - * value, and has to be one of the {@link FuseCode} values. Specifies the - * activation to invoke on the result of each addition. - * - * Outputs: - * * 0: The output, a tensor of the same type as the input tensors. - * The output shape is [D0, D1, ..., sum(Daxis(i)), ..., Dm]. - */ ANEURALNETWORKS_CONCATENATION = 2, - /** Performs an 2-D convolution operation. - * - * The CONV_2D op sweeps a 2-D filter that can mix channels together over a - * batch of images, applying the filter to each window of each image of the - * appropriate size. - * - * The output dimensions are functions of the filter dimensions, stride, and - * padding. - * - * The values in the output tensor are computed as: - * - * output[batch, row, col, channel] = - * sum_{i, j} ( - * input[batch, row + i, col + j, k] * - * filter[channel, row + i, col + j, k] + - * bias[channel] - * ) - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} - * - * Supported tensor rank: 4, with "NHWC" data layout. - * - * Inputs: - * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying - * the input. - * * 1: A 4-D tensor, of shape [depth_out, filter_height, filter_width, - * depth_in], specifying the filter. 
- * * 2: A 1-D tensor, of shape [depth_out], specifying the bias. - * For input tensor of {@link ANEURALNETWORKS_TENSOR_FLOAT32} type, the - * bias should also be of {@link ANEURALNETWORKS_TENSOR_FLOAT32}. For input - * tensor of {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} type, the bias should - * be of {@link ANEURALNETWORKS_TENSOR_INT32}. - * * 3: An INT32 value, specifying the padding on the left, in the ‘width’ - * dimension. - * * 4: An INT32 value, specifying the padding on the right,in the ‘width’ - * dimension. - * * 5: An INT32 value, specifying the padding on the top, in the ‘height’ - * dimension. - * * 6: An INT32 value, specifying the padding on the bottom, in the ‘height’ - * dimension. - * * 7: An INT32 value, specifying the output stride in the ‘width’ dimension. - * * 8: An INT32 value, specifying the output stride in the ‘height’ - * dimension. - * * 9: An INT32 value, and has to be one of the {@link FuseCode} values. - * Specifies the activation to invoke on the result of each addition. - * - * Outputs: - * * 0: The output 4-D tensor, of shape [batches, out_height, out_width, - * depth_out]. - */ ANEURALNETWORKS_CONV_2D = 3, - /** Performs a depthwise 2-D convolution operation. - * - * Given an input tensor of shape [batches, height, width, depth_in] and a - * filter tensor of shape [depth_out, filter_height, filter_width, depth_in] - * containing in_channels convolutional filters of depth 1, DEPTHWISE_CONV - * applies a different filter to each input channel (expanding from 1 channel - * to channel_multiplier channels for each), then concatenates the results - * together. - * - * The output has depth_out = depth_in * depth_multiplier channels. - * The output dimensions are functions of the filter dimensions, stride, and - * padding. 
- * - * The values in the output tensor are computed as: - * - * output[b, i, j, k * channel_multiplier + q] = - * sum_{di, dj} ( - * input[b, strides[1] * i + di, strides[2] * j + dj, k] * - * filter[di, dj, k, q] - * ) - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} - * - * Supported tensor rank: 4, with "NHWC" data layout. - * - * Inputs: - * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying - * the input. - * * 1: A 4-D tensor, of shape [depth_out, filter_height, filter_width, - * depth_in], specifying the filter. - * * 2: A 1-D tensor, of shape [depth_out], specifying the bias. - * For input tensor of {@link ANEURALNETWORKS_TENSOR_FLOAT32} type, the - * bias should also be of {@link ANEURALNETWORKS_TENSOR_FLOAT32}. For input - * tensor of {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} type, the bias should - * be of {@link ANEURALNETWORKS_TENSOR_INT32}. - * * 3: An INT32 value, specifying the padding on the left, in the ‘width’ - * dimension. - * * 4: An INT32 value, specifying the padding on the right,in the ‘width’ - * dimension. - * * 5: An INT32 value, specifying the padding on the top, in the ‘height’ - * dimension. - * * 6: An INT32 value, specifying the padding on the bottom, in the ‘height’ - * dimension. - * * 7: An INT32 value, specifying the output stride in the ‘width’ dimension. - * * 8: An INT32 value, specifying the output stride in the ‘height’ - * dimension. - * * 9: An INT32 value, specifying the depthwise multiplier. - * * 10: An INT32 value, and has to be one of the {@link FuseCode} values. - * Specifies the activation to invoke on the result of each addition. - * - * Outputs: - * * 0: The output 4-D tensor, of shape [batches, out_height, out_width, - * depth_out]. - */ ANEURALNETWORKS_DEPTHWISE_CONV_2D = 4, - /** Rearranges data from depth into blocks of spatial data. 
- * - * More specifically, this op outputs a copy of the input tensor where values - * from the depth dimension are moved in spatial blocks to the height and - * width dimensions. The value block_size indicates the input block size and - * how the data is moved. - * - * Chunks of data of size block_size * block_size from depth are rearranged - * into non-overlapping blocks of size block_size x block_size. - * - * The width of the output tensor is input_depth * block_size, whereas the - * height is input_height * block_size. The depth of the input tensor must be - * divisible by block_size * block_size - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} - * - * Supported tensor rank: 4, with "NHWC" data layout. - * - * Inputs: - * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying - * the input. - * * 1: An INT32 value, specifying the block_size. block_size must be >=1 and - * block_size * block_size must be a divisor of the input depth. - * - * Outputs: - * * 0: The output 4-D tensor, of shape [batch, height*block_size, - * width*block_size, depth/(block_size*block_size)]. - */ ANEURALNETWORKS_DEPTH_TO_SPACE = 5, - /** Dequantizes the input tensor. - * - * The formula is: - * - * output = (input - zero_value) * scale. - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} - * - * Supported tensor rank: up to 4 - * - * Inputs: - * * 0: A tensor of type {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}. - * - * Outputs: - * * 0: The output tensor of same shape as input0, but with type - * {@link ANEURALNETWORKS_TENSOR_FLOAT32}. - */ ANEURALNETWORKS_DEQUANTIZE = 6, - - /** - * Looks up items from a given tensor. - * - * Each item in the output is a raw copy of the corresponding item in - * the input “values”. If the given “lookup” indices are out of bounds, - * the op will fail and an error will be reported. - * - * Inputs: - * * 0: Values. 
An n-D tensor of any type X (where n >= 2). E.g., if n is 2, - * then the shape would be [lookup_dimension, values_dimension], where - * “lookup_dimension” corresponds to the indexing dimension in the lookup - * table, and “values_dimension” to the contents. - * * 1: Lookups. An 1-D tensor of type T, of shape [lookup_size], where - * “lookup_size” is the number of elements to look for, and each entry - * corresponds to the first dimension of the “values” tensor. - * - * Output: - * * 0: A n-D tensor of type X and the same rank and shape as the “values” - * tensor, except for the first dimension which has size “lookup_size”. - */ ANEURALNETWORKS_EMBEDDING_LOOKUP = 7, - - /** Computes element-wise floor() on the input tensor. - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * - * Supported tensor rank: up to 4 - * - * Inputs: - * * 0: A tensor. - * - * Outputs: - * * 0: The output, a tensor of the same type and dimensions as input0. - */ ANEURALNETWORKS_FLOOR = 8, - /** Denotes a fully (densely) connected layer, which connects all elements in - * the input tensor with each element in the output tensor. - * - * This layer implements the operation: - * - * outputs = activation(inputs * weights’ + bias) - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} - * - * Supported tensor rank: up to 4. - * - * Inputs: - * * 0: A tensor, specifying the input. If rank is greater than 2, then it - * gets flattened to a 2-D Tensor. The 2-D Tensor is handled as if dimensions - * corresponded to shape [batch_size, input_size], where “batch_size” - * corresponds to the batching dimension, and “input_size” is the size of the - * input. - * * 1: A 2-D tensor, specifying the weights, of shape [num_units, - * input_size], where "num_units" corresponds to the number of output nodes. - * * 2: A 1-D tensor, of shape [num_units], specifying the bias. 
- * For input tensor of {@link ANEURALNETWORKS_TENSOR_FLOAT32} type, the - * bias should also be of {@link ANEURALNETWORKS_TENSOR_FLOAT32}. For input - * tensor of {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} type, the bias should - * be of {@link ANEURALNETWORKS_TENSOR_INT32}. - * * 3: An INT32 value, and has to be one of the {@link FuseCode} values. - * Specifies the activation to invoke on the result of each addition. - * - * Outputs: - * * 0: The output tensor, of shape [batch_size, num_units]. - */ ANEURALNETWORKS_FULLY_CONNECTED = 9, - - /** - * Looks up values of a hash table with given keys. - * - * Inputs: - * * 0: Lookups. A 1-D int32 tensor with shape [ k ]. - * * 1: Keys. A 1-D int32 tensor with shape [ n ], *MUST* be sorted in - * ascending order. - * * 2: Values. A tensor with shape [ n … ]. - * - * Outputs: - * * 0: Output. A tensor with shape [ k …]. - * * 1: Hits. A uint8 tensor with shape [ k ] indicates whether the lookup - * hits or not. - */ ANEURALNETWORKS_HASHTABLE_LOOKUP = 10, - - /** Applies L2 normalization along the depth dimension. - * - * The values in the output tensor are computed as: - * - * output[batch, row, col, channel] = - * input[batch, row, col, channel] / - * sqrt(sum_{c} pow(input[batch, row, col, c], 2)) - * - * For x with more dimensions, independently normalizes each 1-D slice along - * dimension dim. - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * - * Supported tensor rank: 4, with "NHWC" data layout. - * - * Inputs: - * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the - * input. - * - * Outputs: - * * 0: The output 4-D tensor, of shape [batches, out_height, out_width, - * depth]. - */ ANEURALNETWORKS_L2_NORMALIZATION = 11, - - /** Performs an 2-D L2 pooling operation. - * - * The output dimensions are functions of the filter dimensions, stride, and - * padding. 
- * - * The values in the output tensor are computed as: - * - * output[batch, row, col, channel] = - * sqrt(sum_{i, j} pow(input[batch, row + i, col + j, channel], 2) / - * sum(1)) - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * - * Supported tensor rank: 4, with "NHWC" data layout. - * - * Inputs: - * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the - * input. - * * 1: An INT32 value, specifying the padding on the left, in the ‘width’ - * dimension. - * * 2: An INT32 value, specifying the padding on the right,in the ‘width’ - * dimension. - * * 3: An INT32 value, specifying the padding on the top, in the ‘height’ - * dimension. - * * 4: An INT32 value, specifying the padding on the bottom, in the ‘height’ - * dimension. - * * 5: An INT32 value, specifying the output stride in the ‘width’ dimension. - * * 6: An INT32 value, specifying the output stride in the ‘height’ - * dimension. - * * 7: An INT32 value, specifying the filter width. - * * 8: An INT32 value, specifying the filter height. - * * 9: An INT32 value, and has to be one of the {@link FuseCode} values. - * Specifies the activation to invoke on the result of each addition. - * - * Outputs: - * * 0: The output 4-D tensor, of shape [batches, out_height, out_width, - * depth]. - */ ANEURALNETWORKS_L2_POOL_2D = 12, - /** Applies Local Response Normalization along the depth dimension. - * - * The 4-D input tensor is treated as a 3-D array of 1-D vectors (along the - * last dimension), and each vector is normalized independently. Within a - * given vector, each component is divided by the weighted, squared sum of - * inputs within depth_radius. 
- * - * The output is calculated using this formula: - * - * sqr_sum[a, b, c, d] = - * sum(pow(input[a, b, c, d - depth_radius : d + depth_radius + 1], 2) - * output = input / pow((bias + alpha * sqr_sum), beta) - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * - * Supported tensor rank: 4, with "NHWC" data layout. - * - * Inputs: - * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the - * input. - * * 1: An INT32 value, specifying the radius of the normalization window. - * * 2: A FLOAT32 value, specifying the bias, must not be zero. - * * 3: A FLOAT32 value, specifying the scale factor, alpha. - * * 4: A FLOAT32 value, specifying the exponent, beta. - * - * Outputs: - * * 0: The output tensor of same shape as input0. - */ ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION = 13, - /** Computes sigmoid activation on the input tensor element-wise. - * - * The output is calculated using this formula: - * - * output = 1 / (1 + exp(-input)) - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} - * - * Supported tensor rank: up to 4. - * - * Inputs: - * * 0: A tensor, specifying the input. - * - * Outputs: - * * 0: The output tensor of same shape as input0. - */ ANEURALNETWORKS_LOGISTIC = 14, - - /** - * Projects an input to a bit vector via locality sensitive hashing. - * - * Inputs: - * * 0: Hash functions. Dim.size == 2, DataType: Float. - * Tensor[0].Dim[0]: Number of hash functions. - * Tensor[0].Dim[1]: Number of seeds per hash functions. - * Tensor[0].Dim[1] <= 32 in sparse case. - * - * * 1: Input. Dim.size >= 1, no restriction on DataType. - * * 2: Weight. Optional. Dim.size == 1, DataType: Float. - * If not set, each input element is considered to have the same weight of - * 1.0. - * Tensor[1].Dim[0] == Tensor[2].Dim[0] - * * 3: Type: - * Sparse: Value LSHProjectionType_SPARSE(=1). - * Computed bit vector is considered to be sparse. 
- * Each output element is an int32 made up of multiple bits computed - * from hash functions. - * - * Dense: Value LSHProjectionType_DENSE(=2). - * Computed bit vector is considered to be dense. Each output element - * represents a bit and can take the value of either 0 or 1. - * - * Outputs: - * * 0: If the projection type is sparse: - * Output.Dim == { Tensor[0].Dim[0] } - * A tensor of int32 that represents hash signatures. - * If the projection type is Dense: - * Output.Dim == { Tensor[0].Dim[0] * Tensor[0].Dim[1] } - * A flattened tensor that represents projected bit vectors. - */ ANEURALNETWORKS_LSH_PROJECTION = 15, - - /** - * Long short-term memory unit (LSTM) recurrent network layer. - * - * The default non-peephole implementation is based on: - * http://www.bioinf.jku.at/publications/older/2604.pdf - * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural - * Computation, 9(8):1735-1780, 1997. - * - * The peephole implementation is based on: - * https://research.google.com/pubs/archive/43905.pdf - * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory - * recurrent neural network architectures for large scale acoustic modeling." - * INTERSPEECH, 2014. - * - * The coupling of input and forget gate (CIFG) is based on: - * http://arxiv.org/pdf/1503.04069.pdf - * Greff et al. "LSTM: A Search Space Odyssey" - * - * The class has the following independently optional inputs: - * * If input gate (if CIFG): “input_to_forget_weights”, - * “recurrent_to_input_weights”, “cell_to_input_weights”, “input_gate_bias”. - * * If no peephole connections: “cell_to_input_weights”, - * “cell_to_forget_weights”, “cell_to_output_weights”. - * * If no projection layer: “projection_weights” and “projection_bias”. - * * If no projection bias: “projection_bias”. - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * - * Inputs: - * * 0: Input. 
- * A 2-D tensor of type T, of shape [batch_size, input_size], where - * “batch_size” corresponds to the batching dimension, and “input_size” - * is the size of the input. - * * 1: input_to_input_weights. - * A 2-D tensor of type T, of shape [num_units, input_size], where - * “num_units” corresponds to the number of cell units. - * * 2: input_to_forget_weights. - * A 2-D tensor of type T, of shape [num_units, input_size]. - * * 3: input_to_cell_weights. - * A 2-D tensor of type T, of shape [num_units, input_size]. - * * 4: input_to_output_weights. - * A 2-D tensor of type T, of shape [num_units, input_size]. - * * 5: recurrent_to_input_weights. - * A 2-D tensor of type T, of shape [num_units, output_size], where - * “output_size” corresponds to either the number of cell units (i.e., - * “num_units”), or the second dimension of the “projection_weights”, if - * defined. - * * 6: recurrent_to_forget_weights. - * A 2-D tensor of type T, of shape [num_units, output_size]. - * * 7: recurrent_to_cell_weights. - * A 2-D tensor of type T, of shape [num_units, output_size]. - * * 8: recurrent_to_output_weights. - * A 2-D tensor of type T, of shape [num_units, output_size]. - * * 9: cell_to_input_weights. - * A 1-D tensor of type T, of shape [num_units]. - * * 10:cell_to_forget_weights. - * A 1-D tensor of type T, of shape [num_units]. - * * 11:cell_to_output_weights. - * A 1-D tensor of type T, of shape [num_units]. - * * 12:input_gate_bias. - * A 1-D tensor of type T, of shape [num_units]. - * * 13:forget_gate_bias. - * A 1-D tensor of type T, of shape [num_units]. - * * 14:cell_bias. - * A 1-D tensor of type T, of shape [num_units]. - * * 15:output_gate_bias. - * A 1-D tensor of type T, of shape [num_units]. - * * 16:projection_weights. - * A 2-D tensor of type T, of shape [output_size, num_units]. - * * 17:projection_bias. - * A 1-D tensor of type T, of shape [output_size]. - * - * Parameters: - * * 18:fused_activation_function. 
- * An (optional) ActivationFunctionType indicating the activation - * function. - * If “NONE” is specified then it results in a linear activation. - * * 19:cell_clip. - * A clipping threshold for the cell state, such that values are bound - * within [-cell_clip, cell_clip]. If set to 0.0 then clipping is - * disabled. - * * 20:proj_clip. - * A clipping threshold for the output from the projection layer, such - * that values are bound within [-proj_clip, proj_clip]. If set to 0.0 - * then clipping is disabled. - * - * Outputs: - * * 0: scratch_buffer. - * A 3-D tensor of type T, of shape [batch_size, num_cell, 4]. - * * 1: output_state. - * A 2-D tensor of type T, of shape [batch_size, output_size]. - * * 2: cell_state. - * A 2-D tensor of type T, of shape [batch_size, num_units]. - * * 3: output. - * A 2-D tensor of type T, of shape [batch_size, output_size]. This is - * effectively the same as the current “output_state” value. - */ ANEURALNETWORKS_LSTM = 16, - - /** Performs an 2-D max pooling operation. - * - * The output dimensions are functions of the filter dimensions, stride, and - * padding. - * - * The values in the output tensor are computed as: - * - * output[batch, row, col, channel] = - * max_{i, j} (input[batch, row + i, col + j, channel]) - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} - * - * Supported tensor rank: 4, with "NHWC" data layout. - * - * Inputs: - * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the - * input. - * * 1: An INT32 value, specifying the padding on the left, in the ‘width’ - * dimension. - * * 2: An INT32 value, specifying the padding on the right,in the ‘width’ - * dimension. - * * 3: An INT32 value, specifying the padding on the top, in the ‘height’ - * dimension. - * * 4: An INT32 value, specifying the padding on the bottom, in the ‘height’ - * dimension. 
- * * 5: An INT32 value, specifying the output stride in the ‘width’ dimension. - * * 6: An INT32 value, specifying the output stride in the ‘height’ - * dimension. - * * 7: An INT32 value, specifying the filter width. - * * 8: An INT32 value, specifying the filter height. - * * 9: An INT32 value, and has to be one of the {@link FuseCode} values. - * Specifies the activation to invoke on the result of each addition. - * - * Outputs: - * * 0: The output 4-D tensor, of shape [batches, out_height, out_width, - * depth]. - */ ANEURALNETWORKS_MAX_POOL_2D = 17, - - /** Multiplies two tensors, element-wise. - * - * Takes two input tensors of identical type and compatible dimensions. The - * output is the product of both input tensors, optionally modified by an - * activation function. - * - * Two dimensions are compatible when: - * 1. they are equal, or - * 2. one of them is 1 - * - * The size of the resulting output is the maximum size along each dimension - * of the input operands. It starts with the trailing dimensions, and works - * its way forward. - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * - * Supported tensor rank: up to 4 - * - * Inputs: - * * 0: A tensor. - * * 1: A tensor of the same type, and compatible dimensions as input0. - * * 2: An INT32 value, and has to be one of the {@link FuseCode} values. - * Specifies the activation to invoke on the result of each addition. - * - * Outputs: - * * 0: The product, a tensor of the same type as input0. - */ ANEURALNETWORKS_MUL = 18, - /** Computes rectified linear activation on the input tensor element-wise. - * - * The output is calculated using this formula: - * - * output = max(0, input) - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} - * - * Supported tensor rank: up to 4. - * - * Inputs: - * * 0: A tensor, specifying the input. - * - * Outputs: - * * 0: The output tensor of same shape as input0. 
- */ ANEURALNETWORKS_RELU = 19, - /** Computes rectified linear 1 activation on the input tensor element-wise. - * - * The output is calculated using this formula: - * - * output = min(1.f, max(-1.f, input)) - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} - * - * Supported tensor rank: up to 4. - * - * Inputs: - * * 0: A tensor, specifying the input. - * - * Outputs: - * * 0: The output tensor of same shape as input0. - */ ANEURALNETWORKS_RELU1 = 20, - /** Computes rectified linear 6 activation on the input tensor element-wise. - * - * The output is calculated using this formula: - * - * output = min(6, max(0, input)) - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} - * - * Supported tensor rank: up to 4. - * - * Inputs: - * * 0: A tensor, specifying the input. - * - * Outputs: - * * 0: The output tensor of same shape as input0. - */ ANEURALNETWORKS_RELU6 = 21, - /** Reshapes a tensor. - * - * Given tensor, this operation returns a tensor that has the same values as - * tensor, but with a newly specified shape. - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} - * - * Supported tensor rank: up to 4. - * - * Inputs: - * * 0: A tensor, specifying the tensor to be reshaped. - * * 1: A 1-D tensor of type {@link ANEURALNETWORKS_TENSOR_INT32}, defining - * the shape of the output tensor. The number of elements implied by shape - * must be the same as the number of elements in the input tensor. - * - * Outputs: - * * 0: The output tensor, of shape specified by the input shape. - */ ANEURALNETWORKS_RESHAPE = 22, - /** Resizes images to given size using the bilinear interpretation. - * - * Resized images will be distorted if their original aspect ratio is not the - * same as input. 
- * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * - * Supported tensor rank: 4, with "NHWC" data layout. - * - * Inputs: - * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the - * input. - * * 1: An INT32 value, specifying the output width of the output tensor. - * * 2: An INT32 value, specifying the output height of the output tensor. - * - * Outputs: - * * 0: The output 4-D tensor, of shape [batches, new_height, new_width, - * depth]. - */ ANEURALNETWORKS_RESIZE_BILINEAR = 23, - - /** - * A basic recurrent neural network layer. - * - * This layer implements the operation: - * outputs = state = activation(inputs * input_weights + state * - * recurrent_weights + bias) - * - * Where: - * * “input_weights” is a weight matrix that multiplies the inputs; - * * “recurrent_weights” is a weight matrix that multiplies the current - * “state” which itself is the output from the previous time step - * computation; - * * “bias” is a bias vector (added to each output vector in the batch); - * * “activation” is the function passed as the “fused_activation_function” - * argument (if not “NONE”). - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * - * Inputs: - * * 0: input. - * A 2-D tensor of type T, of shape [batch_size, input_size], where - * “batch_size” corresponds to the batching dimension, and “input_size” - * is the size of the input. - * * 1: weights. - * A 2-D tensor of type T, of shape [num_units, input_size], where - * “num_units” corresponds to the number of units. - * * 2: recurrent_weights. - * A 2-D tensor of type T, of shape [num_units, num_units], with columns - * corresponding to the weights from each unit. - * * 3: bias. - * A 1-D tensor of type T, of shape [num_units]. - * - * For FLOAT32 input tensor, bias must also be FLOAT32. - * For UINT8 input tensor, bias must be INT32. - * - * Parameters - * * 4: fused_activation_function. 
- * An (optional) ActivationFunctionType indicating the activation - * function. If “NONE” is specified then it results in a linear - * activation. - * - * * 5: Hidden state. - * A 2-D tensor of type T, of shape [batch_size, num_units]. - * - * Outputs: - * * 0: output. - * A 2-D tensor of type T, of shape [batch_size, num_units]. This is - * effectively the same as the current state value. - */ ANEURALNETWORKS_RNN = 24, - - /** Computes the softmax activation on the input tensor element-wise, per - * batch, by normalizing the input vector so the maximum coefficient is zero. - * - * The output is calculated using this formula: - * - * output[batch, i] = - * exp((input[batch, i] - max(input[batch, :])) * beta) / - * sum_{k}{exp((input[batch, k] - max(input[batch, :])) * beta)} - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} - * - * Supported tensor rank: 2 or 4. - * - * Inputs: - * * 0: A 2-D or 4-D tensor, specifying the tensor to be reshaped. - * * 1: A FLOAT32 value, specifying the scaling factor for the exponent, beta. - * - * Outputs: - * * 0: The output tensor of same shape as input0. - */ ANEURALNETWORKS_SOFTMAX = 25, - - /** Rearranges blocks of spatial data, into depth. - * - * More specifically, this op outputs a copy of the input tensor where values - * from the height and width dimensions are moved to the depth dimension. The - * value block_size indicates the input block size and how the data is moved. - * - * Chunks of data of size block_size * block_size from depth are rearranged - * into non-overlapping blocks of size block_size x block_size. - * - * The depth of the output tensor is input_depth * block_size * block_size. - * The input tensor's height and width must be divisible by block_size. 
- * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} - * - * Supported tensor rank: 4, with "NHWC" data layout. - * - * Inputs: - * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying - * the input. - * * 1: An INT32 value, specifying the block_size. block_size must be >=1 and - * block_size must be a divisor of both the input height and width. - * - * Outputs: - * * 0: The output 4-D tensor, of shape [batch, height/block_size, - * width/block_size, depth*block_size*block_size]. - */ ANEURALNETWORKS_SPACE_TO_DEPTH = 26, - - /** - * SVDF op is a kind of stateful layer derived from the notion that a - * densely connected layer that's processing a sequence of input frames can - * be approximated by using a singular value decomposition of each of its - * nodes. The implementation is based on: - * - * https://research.google.com/pubs/archive/43813.pdf - * - * P. Nakkiran, R. Alvarez, R. Prabhavalkar, C. Parada. - * “Compressing Deep Neural Networks using a Rank-Constrained Topology”. - * INTERSPEECH, 2015. - * - * It processes the incoming input using a 2-stage filtering mechanism: - * * stage 1 performs filtering on the "features" dimension, whose outputs get - * pushed into a memory of fixed-size memory_size. - * * stage 2 performs filtering on the "time" dimension of the memory_size - * memoized outputs of stage 1. 
- * - * Specifically, for rank 1, this layer implements the operation: - * - * memory = push(conv1d(inputs, weights_feature, feature_dim, "VALID")); - * outputs = activation(memory * weights_time + bias); - * - * Where: - * * “weights_feature” is a weights matrix that processes the inputs (by - * convolving the input with every “feature filter”), and whose outputs get - * pushed, stacked in order, into the fixed-size “memory” (the oldest entry - * gets dropped); - * * “weights_time” is a weights matrix that processes the “memory” (by a - * batched matrix multiplication on the num_units); - * * “bias” is an optional bias vector (added to each output vector in the - * batch); and - * * “activation” is the function passed as the “fused_activation_function” - * argument (if not “NONE”). - * - * Each rank adds a dimension to the weights matrices by means of stacking - * the filters. - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * - * Inputs: - * * 0: input. - * A 2-D tensor of type T, of shape [batch_size, input_size], where - * “batch_size” corresponds to the batching dimension, and “input_size” - * is the size of the input. - * * 1: weights_feature. - * A 2-D tensor of type T, of shape [num_units, input_size], where - * “num_units” corresponds to the number of units. - * * 2: weights_time. - * A 2-D tensor of type T, of shape [num_units, memory_size], where - * “memory_size” corresponds to the fixed-size of the memory. - * * 3: bias. - * A optional 1-D tensor of type T, of shape [num_units]. - * - * For FLOAT32 input tensor, bias must also be FLOAT32. - * For UINT8 input tensor, bias must be INT32. - * - * Parameters: - * * 4: rank. - * The rank of the SVD approximation. - * * 5: fused_activation_function. - * An (optional) ActivationFunctionType indicating the activation - * function. If “NONE” is specified then it results in a linear activation. - * - * Outputs: - * * 0: state. 
- * A 2-D tensor of type T, of shape [batch_size, (memory_size - 1) * - * num_units * rank]. - * * 1: output. - * A 2-D tensor of type T, of shape [batch_size, num_units]. - */ ANEURALNETWORKS_SVDF = 27, - - /** Computes hyperbolic tangent of input tensor element-wise. - * - * The output is calculated using this formula: - * - * output = tanh(input) - * - * Supported tensor types: - * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} - * - * Supported tensor rank: up to 4. - * - * Inputs: - * * 0: A tensor, specifying the input. - * - * Outputs: - * * 0: The output tensor of same shape as input0. - */ ANEURALNETWORKS_TANH = 28, + ANEURALNETWORKS_BATCH_TO_SPACE_ND = 29, + ANEURALNETWORKS_DIV = 30, + ANEURALNETWORKS_MEAN = 31, + ANEURALNETWORKS_PAD = 32, + ANEURALNETWORKS_SPACE_TO_BATCH_ND = 33, + ANEURALNETWORKS_SQUEEZE = 34, + ANEURALNETWORKS_STRIDED_SLICE = 35, + ANEURALNETWORKS_SUB = 36, + ANEURALNETWORKS_TRANSPOSE = 37, }; /** @@ -1080,13 +137,9 @@ enum { * */ enum { - /** NO fused activation function. */ ANEURALNETWORKS_FUSED_NONE = 0, - /** Fused ReLU activation function. */ ANEURALNETWORKS_FUSED_RELU = 1, - /** Fused ReLU1 activation function. */ ANEURALNETWORKS_FUSED_RELU1 = 2, - /** Fused ReLU6 activation function. */ ANEURALNETWORKS_FUSED_RELU6 = 3, }; @@ -1094,20 +147,8 @@ enum { * Execution preferences. */ enum { - /** - * Prefer executing in a way that minimizes battery drain. - * This is desirable for compilations that will be executed often. - */ ANEURALNETWORKS_PREFER_LOW_POWER = 0, - /** - * Prefer returning a single answer as fast as possible, even if this causes - * more power consumption. - */ ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1, - /** - * Prefer maximizing the throughput of successive frames, for example when - * processing successive frames coming from the camera. 
- */ ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2, }; diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index eb451397bd8effdfca0eaba55cb6ce230718b2b9..999c31d4bff9279810a3661f0bb342cc4ef3ddaa 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -23,6 +23,10 @@ limitations under the License. #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" +#ifdef __ANDROID__ +#include +#endif + namespace tflite { // TODO(aselle): FATAL leaves resources hanging. @@ -46,6 +50,32 @@ void FATAL(const char* format, ...) { FATAL("Aborting since tflite returned failure."); \ } +namespace { + +int32_t GetAndroidSdkVersion() { +#ifdef __ANDROID__ + const char* sdkProp = "ro.build.version.sdk"; + char sdkVersion[PROP_VALUE_MAX]; + int length = __system_property_get(sdkProp, sdkVersion); + if (length != 0) { + for (int i = 0; i < length; ++i) { + int digit = sdkVersion[i] - '0'; + if (digit < 0 || digit > 9) { + // Non-numeric SDK version, assume it's higher than expected. + return 0xFFFF; + } + } + return atoi(sdkVersion); + } + FATAL("No %s prop", sdkProp); +#endif // __ANDROID__ + return 0; +} + +static const int32_t kAndroidSdkVersion = GetAndroidSdkVersion(); + +} // namespace + NNAPIAllocation::NNAPIAllocation(const char* filename, ErrorReporter* error_reporter) : MMAPAllocation(filename, error_reporter) { @@ -125,7 +155,6 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter, nn_type, static_cast(tensor->dims->size), reinterpret_cast(tensor->dims->data), scale, zeroPoint}; CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)); - // TODO(aselle): Based on Michael's suggestion, limiting this to read // only memory if (tensor->allocation_type == kTfLiteMmapRo) { @@ -138,7 +167,12 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter, CHECK_NN(ANeuralNetworksModel_setOperandValue( nn_model, next_id, 
tensor->data.raw, tensor->bytes)); } + } else if (tensor->bytes == 0) { + // These size 0 tensors are optional tensors reserved. + CHECK_NN( + ANeuralNetworksModel_setOperandValue(nn_model, next_id, nullptr, 0)); } + ++next_id; } return next_id; @@ -147,7 +181,9 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter, // Adds the operations and their parameters to the NN API model. // 'next-id' is the operand ID of the next operand of the model. void AddOpsAndParams(tflite::Interpreter* interpreter, - ANeuralNetworksModel* nn_model, uint32_t next_id) { + ANeuralNetworksModel* nn_model, uint32_t next_id, + std::vector* model_state_inputs, + std::vector* model_state_outputs) { for (size_t i = 0; i < interpreter->nodes_size(); i++) { const auto* node_and_registration = interpreter->node_and_registration(i); const TfLiteNode& node = node_and_registration->first; @@ -158,6 +194,8 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, // Add the parameters. std::vector augmented_inputs( node.inputs->data, node.inputs->data + node.inputs->size); + std::vector augmented_outputs( + node.outputs->data, node.outputs->data + node.outputs->size); auto add_scalar_int32 = [&nn_model, &augmented_inputs, &next_id](int value) { @@ -177,15 +215,29 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, augmented_inputs.push_back(next_id++); }; + // Handle state tensors of RNN, LSTM, SVDF. + // For each state_out tensor, a corresponding state_in operand needs to be + // created for NNAPI. 
auto duplicate_state_tensor_float32 = - [interpreter, &nn_model, &augmented_inputs](int tensor_id) { + [interpreter, &nn_model, &next_id, &augmented_inputs, + &model_state_inputs, &model_state_outputs](int tensor_id) { const TfLiteTensor* tensor = interpreter->tensor(tensor_id); - CHECK_NN(ANeuralNetworksModel_setOperandValue( - nn_model, tensor_id, tensor->data.raw, tensor->bytes)); - augmented_inputs.push_back(tensor_id); + ANeuralNetworksOperandType operand_type{ + ANEURALNETWORKS_TENSOR_FLOAT32, + static_cast(tensor->dims->size), + reinterpret_cast(tensor->dims->data), + tensor->params.scale, tensor->params.zero_point}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)); + augmented_inputs.push_back(next_id); + model_state_inputs->push_back(next_id); + model_state_outputs->push_back(tensor_id); + next_id++; }; - auto add_add_params = [&add_scalar_int32]() { add_scalar_int32(0); }; + auto add_add_params = [&add_scalar_int32](void* data) { + auto* builtin = reinterpret_cast(data); + add_scalar_int32(builtin->activation); + }; auto add_pooling_params = [&add_scalar_int32](void* data) { auto builtin = reinterpret_cast(data); @@ -245,33 +297,62 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, add_scalar_float32(builtin->proj_clip); }; -#if 0 - auto add_reshape_params = [&](void* data) { - auto builtin = reinterpret_cast(data); - uint32_t tensor_size_shape = builtin->num_dimensions; + // LSTM in NNAPI requires scratch tensor as an output operand. 
+ auto add_lstm_scratch_tensor_float32 = [interpreter, &node, &nn_model, + &next_id, &augmented_outputs]() { + int scratch_buffer_index = node.temporaries->data[0]; + const TfLiteTensor* tensor = interpreter->tensor(scratch_buffer_index); ANeuralNetworksOperandType operand_type{ - ANEURALNETWORKS_TENSOR_INT32, - {static_cast(1), - reinterpret_cast(&tensor_size_shape)}, - 0, - 0}; - CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)) - CHECK_NN(ANeuralNetworksModel_setOperandValue( - nn_model, next_id, builtin->shape, - sizeof(int) * builtin->num_dimensions)); - augmented_inputs.push_back(next_id++); + ANEURALNETWORKS_TENSOR_FLOAT32, + static_cast(tensor->dims->size), + reinterpret_cast(tensor->dims->data), tensor->params.scale, + tensor->params.zero_point}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)); + augmented_outputs.insert(augmented_outputs.begin(), next_id++); + }; + + auto add_mean_params = [&add_scalar_int32](void* data) { + auto builtin = reinterpret_cast(data); + add_scalar_int32(builtin->keep_dims); + }; + + auto add_svdf_params = [&add_scalar_int32](void* data) { + auto builtin = reinterpret_cast(data); + add_scalar_int32(builtin->rank); + add_scalar_int32(builtin->activation); }; -#endif + auto add_rnn_params = [&add_scalar_int32](void* data) { + auto builtin = reinterpret_cast(data); + add_scalar_int32(builtin->activation); + }; + + // Handle optional input tensors. 
+ auto add_optional_tensors = [&nn_model, &augmented_inputs, + &next_id](int nn_type) { + for (size_t idx = 0; idx < augmented_inputs.size(); idx++) { + if (augmented_inputs[idx] == kOptionalTensor) { + const std::vector dim = {0, 0}; + ANeuralNetworksOperandType operand_type{nn_type, 2, dim.data(), 0, 0}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)) + CHECK_NN(ANeuralNetworksModel_setOperandValue(nn_model, next_id, + nullptr, 0)) + augmented_inputs[idx] = next_id++; + } + } + }; + + int nnapi_version = 10; ANeuralNetworksOperationType nn_op_type; + switch (builtin) { case tflite::BuiltinOperator_ADD: nn_op_type = ANEURALNETWORKS_ADD; - add_add_params(); + add_add_params(node.builtin_data); break; case tflite::BuiltinOperator_MUL: nn_op_type = ANEURALNETWORKS_MUL; - add_add_params(); + add_add_params(node.builtin_data); break; case tflite::BuiltinOperator_AVERAGE_POOL_2D: add_pooling_params(node.builtin_data); @@ -330,27 +411,58 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, break; case tflite::BuiltinOperator_LSTM: { duplicate_state_tensor_float32( - node.outputs->data[/*kOutputStateTensor*/ 1]); + node.outputs->data[/*kOutputStateTensor*/ 0]); duplicate_state_tensor_float32( - node.outputs->data[/*kCellStateTensor*/ 2]); + node.outputs->data[/*kCellStateTensor*/ 1]); add_lstm_params(node.builtin_data); + add_lstm_scratch_tensor_float32(); + add_optional_tensors(ANEURALNETWORKS_TENSOR_FLOAT32); nn_op_type = ANEURALNETWORKS_LSTM; break; } + case tflite::BuiltinOperator_SVDF: { + duplicate_state_tensor_float32(node.outputs->data[/*kStateTensor*/ 0]); + add_svdf_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_SVDF; + break; + } + case tflite::BuiltinOperator_RNN: { + duplicate_state_tensor_float32( + node.outputs->data[/*kHiddenStateTensor*/ 0]); + add_rnn_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_RNN; + break; + } + case tflite::BuiltinOperator_EMBEDDING_LOOKUP: + nn_op_type = 
ANEURALNETWORKS_EMBEDDING_LOOKUP; + break; + case tflite::BuiltinOperator_PAD: + nnapi_version = 11; // require NNAPI 1.1 + nn_op_type = ANEURALNETWORKS_PAD; + break; + case tflite::BuiltinOperator_MEAN: + nnapi_version = 11; // require NNAPI 1.1 + add_mean_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_MEAN; + break; + case tflite::BuiltinOperator_DIV: + nnapi_version = 11; // require NNAPI 1.1 + nn_op_type = ANEURALNETWORKS_DIV; + break; + case tflite::BuiltinOperator_SUB: + nnapi_version = 11; // require NNAPI 1.1 + nn_op_type = ANEURALNETWORKS_SUB; + break; case tflite::BuiltinOperator_CONCAT_EMBEDDINGS: case tflite::BuiltinOperator_LSH_PROJECTION: - case tflite::BuiltinOperator_SVDF: case tflite::BuiltinOperator_HASHTABLE_LOOKUP: - case tflite::BuiltinOperator_RNN: case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: - case tflite::BuiltinOperator_EMBEDDING_LOOKUP: case tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: case tflite::BuiltinOperator_L2_NORMALIZATION: case tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: - case tflite::BuiltinOperator_PAD: case tflite::BuiltinOperator_PADV2: case tflite::BuiltinOperator_RESIZE_BILINEAR: case tflite::BuiltinOperator_CALL: @@ -361,9 +473,6 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_BATCH_TO_SPACE_ND: case tflite::BuiltinOperator_TOPK_V2: case tflite::BuiltinOperator_TRANSPOSE: - case tflite::BuiltinOperator_MEAN: - case tflite::BuiltinOperator_DIV: - case tflite::BuiltinOperator_SUB: case tflite::BuiltinOperator_SPLIT: case tflite::BuiltinOperator_SQUEEZE: case tflite::BuiltinOperator_STRIDED_SLICE: @@ -382,6 +491,15 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_LESS_EQUAL: case tflite::BuiltinOperator_NEG: case 
tflite::BuiltinOperator_SELECT: + case tflite::BuiltinOperator_SLICE: + case tflite::BuiltinOperator_SIN: + case tflite::BuiltinOperator_LOG: + case tflite::BuiltinOperator_TRANSPOSE_CONV: + case tflite::BuiltinOperator_TILE: + case tflite::BuiltinOperator_EXPAND_DIMS: + case tflite::BuiltinOperator_SPARSE_TO_DENSE: + case tflite::BuiltinOperator_EQUAL: + case tflite::BuiltinOperator_NOT_EQUAL: FATAL("Op code %d is currently not delegated to NNAPI", builtin); nn_op_type = -1; // set to invalid break; @@ -391,11 +509,16 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, break; } + if (nnapi_version == 11 && kAndroidSdkVersion < 28) { + FATAL("Op %d needs NNAPI1.1", builtin); + } + // Add the operation. CHECK_NN(ANeuralNetworksModel_addOperation( nn_model, nn_op_type, static_cast(augmented_inputs.size()), - augmented_inputs.data(), static_cast(node.outputs->size), - reinterpret_cast(node.outputs->data))); + augmented_inputs.data(), + static_cast(augmented_outputs.size()), + reinterpret_cast(augmented_outputs.data()))); } } @@ -419,12 +542,25 @@ TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) { } uint32_t next_id = addTensorOperands(interpreter, nn_model_, skip_list); - AddOpsAndParams(interpreter, nn_model_, next_id); + AddOpsAndParams(interpreter, nn_model_, next_id, &model_states_inputs_, + &model_states_outputs_); + + std::vector augmented_inputs = interpreter->inputs(); + std::vector augmented_outputs = interpreter->outputs(); + + // All state tensors input/output need to be treated as model input/output. 
+ augmented_inputs.insert(augmented_inputs.end(), + model_states_inputs_.begin(), + model_states_inputs_.end()); + augmented_outputs.insert(augmented_outputs.end(), + model_states_outputs_.begin(), + model_states_outputs_.end()); + CHECK_NN(ANeuralNetworksModel_identifyInputsAndOutputs( - nn_model_, static_cast(interpreter->inputs().size()), - reinterpret_cast(interpreter->inputs().data()), - static_cast(interpreter->outputs().size()), - reinterpret_cast(interpreter->outputs().data()))); + nn_model_, static_cast(augmented_inputs.size()), + reinterpret_cast(augmented_inputs.data()), + static_cast(augmented_outputs.size()), + reinterpret_cast(augmented_outputs.data()))); CHECK_NN(ANeuralNetworksModel_finish(nn_model_)); } if (!nn_compiled_model_) { @@ -451,6 +587,7 @@ TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) { CHECK_NN(ANeuralNetworksExecution_setInput( execution, i, nullptr, tensor->data.raw, tensor->bytes)); } + // Tell nn api where to place final data. for (size_t i = 0; i < interpreter->outputs().size(); i++) { int output = interpreter->outputs()[i]; @@ -458,6 +595,24 @@ TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) { CHECK_NN(ANeuralNetworksExecution_setOutput( execution, i, nullptr, tensor->data.raw, tensor->bytes)); } + + // The state_out of the previous invocation needs to be mapped to state_in of + // the current invocation. + for (size_t i = 0; i < model_states_outputs_.size(); i++) { + int state_tensor_idx = model_states_outputs_[i]; + TfLiteTensor* tensor = interpreter->tensor(state_tensor_idx); + // Here we are using a deep copy for state_in tensors so that we are not + // reading and writing into the same buffer during an invocation. + // TODO(miaowang): using double shared buffer to minimize the copies. + CHECK_NN(ANeuralNetworksExecution_setInput( + execution, i + interpreter->inputs().size(), nullptr, tensor->data.raw, + tensor->bytes)); + // Tell NNAPI where to output the state_out. 
+ CHECK_NN(ANeuralNetworksExecution_setOutput( + execution, i + interpreter->outputs().size(), nullptr, tensor->data.raw, + tensor->bytes)); + } + // Currently use blocking compute. ANeuralNetworksEvent* event = nullptr; CHECK_NN(ANeuralNetworksExecution_startCompute(execution, &event)); diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h index e98000929a1168c786f6c18f498f9d1d72311ada..94dea4f9b23f208fddbacd3c77d889ea753a8a1d 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.h +++ b/tensorflow/contrib/lite/nnapi_delegate.h @@ -59,6 +59,14 @@ class NNAPIDelegate { ANeuralNetworksModel* nn_model_ = nullptr; // The NN API compilation handle ANeuralNetworksCompilation* nn_compiled_model_ = nullptr; + + // List of state tensors for LSTM, RNN, SVDF. + // NN API does not allow ops to maintain states across multiple + // invocations. We need to manually create state input tensors from + // corresponding state output tensors of TFLite operations, and map them + // correctly. + std::vector model_states_inputs_; + std::vector model_states_outputs_; }; } // namespace tflite diff --git a/tensorflow/contrib/lite/op_resolver.cc b/tensorflow/contrib/lite/op_resolver.cc new file mode 100644 index 0000000000000000000000000000000000000000..f6e435e982443218b5e6328e48e1ce0d2393224c --- /dev/null +++ b/tensorflow/contrib/lite/op_resolver.cc @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/op_resolver.h" +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +const TfLiteRegistration* MutableOpResolver::FindOp(tflite::BuiltinOperator op, + int version) const { + auto it = builtins_.find(std::make_pair(op, version)); + return it != builtins_.end() ? &it->second : nullptr; +} + +const TfLiteRegistration* MutableOpResolver::FindOp(const char* op, + int version) const { + auto it = custom_ops_.find(std::make_pair(op, version)); + return it != custom_ops_.end() ? &it->second : nullptr; +} + +void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op, + TfLiteRegistration* registration, + int min_version, int max_version) { + for (int version = min_version; version <= max_version; ++version) { + TfLiteRegistration new_registration = *registration; + new_registration.builtin_code = op; + new_registration.version = version; + auto op_key = std::make_pair(op, version); + builtins_[op_key] = new_registration; + } +} + +void MutableOpResolver::AddCustom(const char* name, + TfLiteRegistration* registration, + int min_version, int max_version) { + for (int version = min_version; version <= max_version; ++version) { + TfLiteRegistration new_registration = *registration; + new_registration.builtin_code = BuiltinOperator_CUSTOM; + new_registration.version = version; + auto op_key = std::make_pair(name, version); + custom_ops_[op_key] = new_registration; + } +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/op_resolver.h b/tensorflow/contrib/lite/op_resolver.h new file mode 100644 index 0000000000000000000000000000000000000000..9d7e3f20854a3596181ffa885cc17cfdbd16356e --- /dev/null +++ b/tensorflow/contrib/lite/op_resolver.h @@ -0,0 +1,94 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_ +#define TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_ + +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/util.h" + +namespace tflite { + +// Abstract interface that returns TfLiteRegistrations given op codes or custom +// op names. This is the mechanism that ops being referenced in the flatbuffer +// model are mapped to executable function pointers (TfLiteRegistrations). +class OpResolver { + public: + // Finds the op registration for a builtin operator by enum code. + virtual const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, + int version) const = 0; + // Finds the op registration of a custom operator by op name. + virtual const TfLiteRegistration* FindOp(const char* op, + int version) const = 0; + virtual ~OpResolver() {} +}; + +// Some versions of gcc don't support partial specialization in class scope, +// so these are defined in a namespace. 
+namespace op_resolver_hasher { +template +struct ValueHasher { + size_t operator()(const V& v) const { return std::hash()(v); } +}; + +template <> +struct ValueHasher { + size_t operator()(const tflite::BuiltinOperator& v) const { + return std::hash()(static_cast(v)); + } +}; + +template +struct OperatorKeyHasher { + size_t operator()(const T& x) const { + size_t a = ValueHasher()(x.first); + size_t b = ValueHasher()(x.second); + return CombineHashes({a, b}); + } +}; +} // namespace op_resolver_hasher + +// An OpResolver that is mutable, also used as the op in gen_op_registration. +// A typical usage: +// MutableOpResolver resolver; +// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD()); +// resolver.AddCustom("CustomOp", Register_CUSTOM_OP()); +// InterpreterBuilder(model, resolver)(&interpreter); +class MutableOpResolver : public OpResolver { + public: + const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, + int version) const override; + const TfLiteRegistration* FindOp(const char* op, int version) const override; + void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration, + int min_version = 1, int max_version = 1); + void AddCustom(const char* name, TfLiteRegistration* registration, + int min_version = 1, int max_version = 1); + + private: + typedef std::pair BuiltinOperatorKey; + typedef std::pair CustomOperatorKey; + + std::unordered_map > + builtins_; + std::unordered_map > + custom_ops_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_ diff --git a/tensorflow/contrib/lite/op_resolver_test.cc b/tensorflow/contrib/lite/op_resolver_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..10b7e319722faabd5c1c7442d1c820649affd1ca --- /dev/null +++ b/tensorflow/contrib/lite/op_resolver_test.cc @@ -0,0 +1,129 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/op_resolver.h" + +#include +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { +namespace { + +// We need some dummy functions to identify the registrations. +TfLiteStatus DummyInvoke(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteRegistration* GetDummyRegistration() { + static TfLiteRegistration registration = { + .init = nullptr, + .free = nullptr, + .prepare = nullptr, + .invoke = DummyInvoke, + }; + return &registration; +} + +TEST(MutableOpResolverTest, FinOp) { + MutableOpResolver resolver; + resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration()); + + const TfLiteRegistration* found_registration = + resolver.FindOp(BuiltinOperator_ADD, 1); + ASSERT_NE(found_registration, nullptr); + EXPECT_TRUE(found_registration->invoke == DummyInvoke); + EXPECT_EQ(found_registration->builtin_code, BuiltinOperator_ADD); + EXPECT_EQ(found_registration->version, 1); +} + +TEST(MutableOpResolverTest, FindMissingOp) { + MutableOpResolver resolver; + resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration()); + + const TfLiteRegistration* found_registration = + resolver.FindOp(BuiltinOperator_CONV_2D, 1); + EXPECT_EQ(found_registration, nullptr); +} + +TEST(MutableOpResolverTest, RegisterOpWithMultipleVersions) { + MutableOpResolver resolver; + // The kernel supports version 2 
and 3 + resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration(), 2, 3); + + const TfLiteRegistration* found_registration; + + found_registration = resolver.FindOp(BuiltinOperator_ADD, 2); + ASSERT_NE(found_registration, nullptr); + EXPECT_TRUE(found_registration->invoke == DummyInvoke); + EXPECT_EQ(found_registration->version, 2); + + found_registration = resolver.FindOp(BuiltinOperator_ADD, 3); + ASSERT_NE(found_registration, nullptr); + EXPECT_TRUE(found_registration->invoke == DummyInvoke); + EXPECT_EQ(found_registration->version, 3); +} + +TEST(MutableOpResolverTest, FindOpWithUnsupportedVersions) { + MutableOpResolver resolver; + // The kernel supports version 2 and 3 + resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration(), 2, 3); + + const TfLiteRegistration* found_registration; + + found_registration = resolver.FindOp(BuiltinOperator_ADD, 1); + EXPECT_EQ(found_registration, nullptr); + + found_registration = resolver.FindOp(BuiltinOperator_ADD, 4); + EXPECT_EQ(found_registration, nullptr); +} + +TEST(MutableOpResolverTest, FindCustomOp) { + MutableOpResolver resolver; + resolver.AddCustom("AWESOME", GetDummyRegistration()); + + const TfLiteRegistration* found_registration = resolver.FindOp("AWESOME", 1); + ASSERT_NE(found_registration, nullptr); + EXPECT_EQ(found_registration->builtin_code, BuiltinOperator_CUSTOM); + EXPECT_TRUE(found_registration->invoke == DummyInvoke); + EXPECT_EQ(found_registration->version, 1); + // TODO(ycling): The `custom_name` in TfLiteRegistration isn't properly + // filled yet. Fix this and add tests. 
+} + +TEST(MutableOpResolverTest, FindMissingCustomOp) { + MutableOpResolver resolver; + resolver.AddCustom("AWESOME", GetDummyRegistration()); + + const TfLiteRegistration* found_registration = + resolver.FindOp("EXCELLENT", 1); + EXPECT_EQ(found_registration, nullptr); +} + +TEST(MutableOpResolverTest, FindCustomOpWithUnsupportedVersion) { + MutableOpResolver resolver; + resolver.AddCustom("AWESOME", GetDummyRegistration()); + + const TfLiteRegistration* found_registration = resolver.FindOp("AWESOME", 2); + EXPECT_EQ(found_registration, nullptr); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/profiling/BUILD b/tensorflow/contrib/lite/profiling/BUILD index 15999e5d4188db1e191936ae6d84faf8cce5ca6e..a162b87b8f98576ec7c3b2623d1d34f2baef6cce 100644 --- a/tensorflow/contrib/lite/profiling/BUILD +++ b/tensorflow/contrib/lite/profiling/BUILD @@ -2,9 +2,11 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") + common_copts = [ "-Wall", -] +] + tflite_copts() cc_library( name = "profiler", @@ -29,6 +31,43 @@ cc_library( name = "profile_buffer", hdrs = ["profile_buffer.h"], copts = common_copts, + deps = [":time"], +) + +cc_library( + name = "time", + srcs = ["time.cc"], + hdrs = ["time.h"], + copts = common_copts, +) + +cc_library( + name = "profile_summarizer", + srcs = ["profile_summarizer.cc"], + hdrs = ["profile_summarizer.h"], + copts = common_copts, + deps = [ + ":profiler", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/core:stats_calculator_portable", + ], +) + +cc_test( + name = "profile_summarizer_test", + srcs = ["profile_summarizer_test.cc"], + copts = common_copts, + deps = [ + ":profile_summarizer", + 
"//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:kernel_util", + "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], ) cc_test( diff --git a/tensorflow/contrib/lite/profiling/profile_buffer.h b/tensorflow/contrib/lite/profiling/profile_buffer.h index 299b2a9cad161ce05ba68f39cf612f9866a0b656..65d86dce47f397c7dad6cc2beb8ffa1f95b29d45 100644 --- a/tensorflow/contrib/lite/profiling/profile_buffer.h +++ b/tensorflow/contrib/lite/profiling/profile_buffer.h @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "tensorflow/contrib/lite/profiling/time.h" + namespace tflite { namespace profiling { @@ -74,7 +76,7 @@ class ProfileBuffer { if (!enabled_) { return kInvalidEventHandle; } - uint64_t timestamp = NowMicros(); + uint64_t timestamp = time::NowMicros(); int index = current_index_ % event_buffer_.size(); event_buffer_[index].tag = tag; event_buffer_[index].event_type = event_type; @@ -103,7 +105,7 @@ class ProfileBuffer { } int event_index = event_handle % max_size; - event_buffer_[event_index].end_timestamp_us = NowMicros(); + event_buffer_[event_index].end_timestamp_us = time::NowMicros(); } // Returns the size of the buffer. @@ -134,12 +136,6 @@ class ProfileBuffer { } private: - static uint64_t NowMicros() { - // TODO(shashishekhar): Refactor this to a separate file. 
- struct timeval tv; - gettimeofday(&tv, nullptr); - return static_cast(tv.tv_sec) * 1000000 + tv.tv_usec; - } bool enabled_; uint32_t current_index_; std::vector event_buffer_; diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.cc b/tensorflow/contrib/lite/profiling/profile_summarizer.cc new file mode 100644 index 0000000000000000000000000000000000000000..45388b500c7897c8b33b49eb6ab4e9f8c4fdb37c --- /dev/null +++ b/tensorflow/contrib/lite/profiling/profile_summarizer.cc @@ -0,0 +1,148 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/lite/profiling/profile_summarizer.h" + +#include + +#include "tensorflow/contrib/lite/schema/schema_generated.h" + +namespace tflite { +namespace profiling { +namespace { + +using Detail = tensorflow::StatsCalculator::Detail; + +struct OperatorDetails { + std::string name; + std::vector inputs; + std::vector outputs; +}; + +std::string GetTensorName(const tflite::Interpreter& interpreter, + int tensor_index) { + const auto tensor = interpreter.tensor(tensor_index); + if (tensor == nullptr || tensor->name == nullptr) { + return "Unknown"; + } + return tensor->name; +} +std::vector GetTensorNames(const tflite::Interpreter& interpreter, + const TfLiteIntArray* tensor_indices) { + std::vector tensors; + tensors.reserve(tensor_indices->size); + for (int i = 0; i < tensor_indices->size; i++) { + tensors.push_back(GetTensorName(interpreter, tensor_indices->data[i])); + } + return tensors; +} + +std::string ToString(const std::vector& str_vector) { + std::stringstream stream; + stream << "["; + bool first = true; + for (const auto& s : str_vector) { + if (!first) { + stream << ", "; + } else { + first = false; + } + stream << s; + } + stream << "]"; + return stream.str(); +} + +OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter, + int node_index) { + auto node_reg = interpreter.node_and_registration(node_index); + auto inputs = node_reg->first.inputs; + auto outputs = node_reg->first.outputs; + int code = node_reg->second.builtin_code; + const char* op_name = nullptr; + if (code == tflite::BuiltinOperator_CUSTOM) { + const char* custom_name = node_reg->second.custom_name; + op_name = custom_name ? 
custom_name : "UnknownCustomOp"; + } else { + op_name = tflite::EnumNamesBuiltinOperator()[code]; + } + OperatorDetails details; + details.name = op_name; + details.inputs = GetTensorNames(interpreter, inputs); + details.outputs = GetTensorNames(interpreter, outputs); + return details; +} + +tensorflow::StatSummarizerOptions GetProfileSummarizerOptions() { + auto options = tensorflow::StatSummarizerOptions(); + options.show_summary = true; + options.show_memory = false; + return options; +} + +} // namespace + +ProfileSummarizer::ProfileSummarizer() + : stats_calculator_( + new ::tensorflow::StatsCalculator(GetProfileSummarizerOptions())) {} + +void ProfileSummarizer::ProcessProfiles( + const std::vector& profile_stats, + const tflite::Interpreter& interpreter) { + std::vector events; + std::copy_if(profile_stats.begin(), profile_stats.end(), + std::back_inserter(events), [](const ProfileEvent* e) { + return e->event_type == + ProfileEvent::EventType::OPERATOR_INVOKE_EVENT && + e->end_timestamp_us >= e->begin_timestamp_us; + }); + // Sort with begin_time. 
+ std::sort(events.begin(), events.end(), + [](const ProfileEvent* const& a, const ProfileEvent* const& b) { + return a->begin_timestamp_us < b->begin_timestamp_us; + }); + if (events.empty()) { + return; + } + + int64_t base_start_us = events[0]->begin_timestamp_us; + int node_num = 0; + int64_t curr_total_us = 0; + std::map details; + for (auto event : events) { + auto op_details = GetOperatorDetails(interpreter, event->event_metadata); + auto node_name = ToString(op_details.outputs); + auto result = details.emplace(node_name, Detail()); + Detail* detail = &(result.first->second); + detail->start_us.UpdateStat(event->begin_timestamp_us - base_start_us); + int64_t node_exec_time = + event->end_timestamp_us - event->begin_timestamp_us; + detail->rel_end_us.UpdateStat(node_exec_time); + curr_total_us += node_exec_time; + ++node_num; + + if (result.second) { + detail->name = node_name; + detail->type = op_details.name; + detail->run_order = node_num; + detail->times_called = 0; + } + ++detail->times_called; + } + stats_calculator_->UpdateDetails(details); + stats_calculator_->UpdateRunTotalUs(curr_total_us); +} +} // namespace profiling +} // namespace tflite diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.h b/tensorflow/contrib/lite/profiling/profile_summarizer.h new file mode 100644 index 0000000000000000000000000000000000000000..a529ff87428d70d002241311d7f70f185521020f --- /dev/null +++ b/tensorflow/contrib/lite/profiling/profile_summarizer.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_ +#define TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_ + +#include + +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/profiling/profiler.h" +#include "tensorflow/core/util/stats_calculator.h" + +namespace tflite { +namespace profiling { + +// Creates a summary of operator invocations in the interpreter. +class ProfileSummarizer { + public: + ProfileSummarizer(); + virtual ~ProfileSummarizer() {} + + // Process profile events to update statistics for operator invocations. + void ProcessProfiles(const std::vector& profile_stats, + const tflite::Interpreter& interpreter); + + // Returns a string detailing the accumulated runtime stats in a tab-separated + // format which can be pasted into a spreadsheet for further analysis. + std::string GetOutputString() const { + return stats_calculator_->GetOutputString(); + } + + std::string GetShortSummary() const { + return stats_calculator_->GetShortSummary(); + } + + private: + std::unique_ptr stats_calculator_; +}; + +} // namespace profiling +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_ diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..35cf780713b93db559f86dcaf62e1ac004b5049a --- /dev/null +++ b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc @@ -0,0 +1,116 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/profiling/profile_summarizer.h" +#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/contrib/lite/version.h" + +namespace tflite { +namespace profiling { + +namespace { + +TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input1 = tflite::GetInput(context, node, /*index=*/0); + const TfLiteTensor* input2 = tflite::GetInput(context, node, /*index=*/1); + + TfLiteTensor* output = GetOutput(context, node, /*index=*/0); + + int32_t* output_data = output->data.i32; + *output_data = *(input1->data.i32) + *(input2->data.i32); + return kTfLiteOk; +} + +TfLiteRegistration* RegisterSimpleOp() { + static TfLiteRegistration registration = {nullptr, + nullptr, + nullptr, + SimpleOpEval, + tflite::BuiltinOperator_CUSTOM, + "SimpleOpEval", + 1}; + return &registration; +} + +class SimpleOpModel : public SingleOpModel { + public: + void Init(); + tflite::Interpreter* GetInterpreter() { return interpreter_.get(); } + void SetInputs(int32_t x, int32_t y) { + PopulateTensor(inputs_[0], {x}); + PopulateTensor(inputs_[1], {y}); + } + int32_t GetOutput() { return ExtractVector(output_)[0]; } + + private: + int inputs_[2]; + int output_; +}; + +void SimpleOpModel::Init() { +
inputs_[0] = AddInput({TensorType_INT32, {1}}); + inputs_[1] = AddInput({TensorType_INT32, {1}}); + output_ = AddOutput({TensorType_INT32, {}}); + SetCustomOp("SimpleAdd", {}, RegisterSimpleOp); + BuildInterpreter({GetShape(inputs_[0]), GetShape(inputs_[1])}); +} + +TEST(ProfileSummarizerTest, Empty) { + ProfileSummarizer summarizer; + std::string output = summarizer.GetOutputString(); + EXPECT_GT(output.size(), 0); +} + +#ifdef TFLITE_PROFILING_ENABLED +TEST(ProfileSummarizerTest, Interpreter) { + Profiler profiler; + SimpleOpModel m; + m.Init(); + auto interpreter = m.GetInterpreter(); + interpreter->SetProfiler(&profiler); + profiler.StartProfiling(); + m.SetInputs(1, 2); + m.Invoke(); + // 3 = 1 + 2 + EXPECT_EQ(m.GetOutput(), 3); + profiler.StopProfiling(); + ProfileSummarizer summarizer; + auto events = profiler.GetProfileEvents(); + EXPECT_EQ(1, events.size()); + summarizer.ProcessProfiles(profiler.GetProfileEvents(), *interpreter); + auto output = summarizer.GetOutputString(); + // TODO(shashishekhar): Add a better test here. + ASSERT_TRUE(output.find("SimpleOp") != std::string::npos) << output; +} +#endif + +} // namespace +} // namespace profiling +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/profiling/time.cc b/tensorflow/contrib/lite/profiling/time.cc new file mode 100644 index 0000000000000000000000000000000000000000..446660bb747cd6e3b694669b64ac1d95cf415fbe --- /dev/null +++ b/tensorflow/contrib/lite/profiling/time.cc @@ -0,0 +1,29 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/profiling/time.h" + +#include <sys/time.h> + +namespace tflite { +namespace profiling { +namespace time { +uint64_t NowMicros() { + struct timeval tv; + gettimeofday(&tv, nullptr); + return static_cast<uint64_t>(tv.tv_sec) * 1000000 + tv.tv_usec; +} +} // namespace time +} // namespace profiling +} // namespace tflite diff --git a/tensorflow/contrib/lite/profiling/time.h b/tensorflow/contrib/lite/profiling/time.h new file mode 100644 index 0000000000000000000000000000000000000000..cc2ec319b8a95b3efa0aab0ac9f97a88bf7b5536 --- /dev/null +++ b/tensorflow/contrib/lite/profiling/time.h @@ -0,0 +1,27 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_PROFILING_TIME_H_ +#define TENSORFLOW_CONTRIB_LITE_PROFILING_TIME_H_ + +#include + +namespace tflite { +namespace profiling { +namespace time { +uint64_t NowMicros(); +} // namespace time +} // namespace profiling +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_PROFILING_TIME_H_ diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index 4920e83970d1cb7f60a38f95ea05986d52b0bbe7..27909a9458f6b09f96cb556a5254f01e54f46e05 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -36,6 +36,16 @@ py_test( ], ) +py_binary( + name = "tflite_convert", + srcs = ["tflite_convert.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":lite", + ], +) + py_library( name = "lite", srcs = ["lite.py"], @@ -45,7 +55,23 @@ py_library( ":convert", ":convert_saved_model", ":interpreter", + ":lite_constants", ":op_hint", + "//tensorflow/python:graph_util", + "//tensorflow/python/saved_model:constants", + "//tensorflow/python/saved_model:loader", + "//tensorflow/python/tools:freeze_graph_lib", + ], +) + +py_test( + name = "lite_test", + srcs = ["lite_test.py"], + data = [":interpreter_test_data"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + ":lite", ], ) @@ -111,9 +137,9 @@ py_library( visibility = ["//visibility:public"], deps = [ ":convert", - ":lite_constants", "//tensorflow/contrib/saved_model:saved_model_py", "//tensorflow/python:graph_util", + "//tensorflow/python:platform", "//tensorflow/python/tools:freeze_graph_lib", ], ) @@ -150,20 +176,3 @@ py_test( "//tensorflow/python/saved_model", ], ) - -py_binary( - name = "convert_saved_model_to_frozen_graph", - srcs = ["convert_saved_model_to_frozen_graph.py"], - srcs_version = "PY2AND3", - deps = [ - ":convert_saved_model", - ], -) - -# Transitive dependencies of this target will be 
included in the pip package. -py_library( - name = "tf_lite_py_pip", - deps = [ - ":convert_saved_model", - ], -) diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py index c4200c879ba0e17b3bd183f4004eb75ebdd2f5ee..df39d7ff50f79d433dafa88ba057110634183e19 100644 --- a/tensorflow/contrib/lite/python/convert.py +++ b/tensorflow/contrib/lite/python/convert.py @@ -111,37 +111,75 @@ def tensor_name(x): return x.name.split(":")[0] -def toco_convert(input_data, - input_tensors, - output_tensors, - inference_type=lite_constants.FLOAT, - input_format=lite_constants.TENSORFLOW_GRAPHDEF, - output_format=lite_constants.TFLITE, - quantized_input_stats=None, - drop_control_dependency=True): - """Convert a model using TOCO from `input_format` to `output_format`. +def build_toco_convert_protos(input_tensors, + output_tensors, + inference_type=lite_constants.FLOAT, + inference_input_type=None, + input_format=lite_constants.TENSORFLOW_GRAPHDEF, + output_format=lite_constants.TFLITE, + quantized_input_stats=None, + default_ranges_stats=None, + drop_control_dependency=True, + reorder_across_fake_quant=False, + allow_custom_ops=False, + change_concat_input_ranges=False, + quantize_weights=False, + dump_graphviz_dir=None, + dump_graphviz_video=False): + """Builds protocol buffers describing a conversion of a model using TOCO. Typically this is to convert from TensorFlow GraphDef to TFLite, in which case the default `input_format` and `output_format` are sufficient. Args: - input_data: Input data (i.e. often `sess.graph_def`). input_tensors: List of input tensors. Type and shape are computed using `foo.get_shape()` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). - inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`. - input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF). 
- output_format: Type of data to write (currently must be TFLITE or - GRAPHVIZ_DOT) - quantized_input_stats: For each member of input_tensors the mean and - std deviation of training data. Only needed if `inference_type` is - `QUANTIZED_UINT8`. - drop_control_dependency: Drops control dependencies silently. This is due - to tf lite not supporting control dependencies. + inference_type: Target data type of arrays in the output file. Currently + must be `{FLOAT, QUANTIZED_UINT8, STRING}`. (default FLOAT) + inference_input_type: Target data type of input arrays. Allows for a + different type for input arrays in the case of quantization. Currently + must be `{FLOAT, QUANTIZED_UINT8, STRING}`. (default `inference_type`) + input_format: Type of data to read. Currently must be + `{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF) + output_format: Output file format. Currently must be `{TFLITE, + GRAPHVIZ_DOT}`. (default TFLITE) + quantized_input_stats: List of tuples of integers representing the mean and + standard deviation. Each tuple maps to the corresponding input tensor. + Only needed if `inference_type` is `QUANTIZED_UINT8`. (default None) + default_ranges_stats: Tuple of integers representing (min, max) range values + for all arrays without a specified range. Intended for experimenting with + quantization via "dummy quantization". (default None) + drop_control_dependency: Boolean indicating whether to drop control + dependencies silently. This is due to TFLite not supporting control + dependencies. (default True) + reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant + nodes in unexpected locations. Used when the location of the FakeQuant + nodes is preventing graph transformations necessary to convert the graph. + Results in a graph that differs from the quantized training graph, + potentially causing differing arithmetic behavior. (default False) + allow_custom_ops: Boolean indicating whether to allow custom operations.
+ When false any unknown operation is an error. When true, custom ops are + created for any op that is unknown. The developer will need to provide + these to the TensorFlow Lite runtime with a custom resolver. + (default False) + change_concat_input_ranges: Boolean to change behavior of min/max ranges for + inputs and outputs of the concat operator for quantized models. Changes + the ranges of concat operator overlap when true. (default False) + quantize_weights: Boolean indicating whether to store weights as quantized + weights followed by dequantize operations. Computation is still done in + float, but reduces model size (at the cost of accuracy and latency). + (default False) + dump_graphviz_dir: Full filepath of folder to dump the graphs at various + stages of processing GraphViz .dot files. Preferred over + --output_format=GRAPHVIZ_DOT in order to keep the requirements of the + output file. (default None) + dump_graphviz_video: Boolean indicating whether to dump the graph after + every graph transformation. (default False) Returns: - The converted data. For example if tflite was the destination, then - this will be a tflite flatbuffer in a bytes array. + model_flags, toco_flags: two protocol buffers describing the conversion + process. 
Raises: ValueError: If the input tensor type is unknown @@ -151,9 +189,21 @@ def toco_convert(input_data, toco = _toco_flags_pb2.TocoFlags() toco.input_format = input_format toco.output_format = output_format + toco.inference_type = inference_type + if inference_input_type: + toco.inference_input_type = inference_input_type toco.drop_control_dependency = drop_control_dependency + toco.reorder_across_fake_quant = reorder_across_fake_quant + toco.allow_custom_ops = allow_custom_ops + toco.quantize_weights = quantize_weights + if default_ranges_stats: + toco.default_ranges_min = default_ranges_stats[0] + toco.default_ranges_max = default_ranges_stats[1] + if dump_graphviz_dir: + toco.dump_graphviz_dir = dump_graphviz_dir + toco.dump_graphviz_include_video = dump_graphviz_video model = _model_flags_pb2.ModelFlags() - toco.inference_type = inference_type + model.change_concat_input_ranges = change_concat_input_ranges for idx, input_tensor in enumerate(input_tensors): if input_tensor.dtype == _dtypes.float32: tflite_input_type = lite_constants.FLOAT @@ -161,7 +211,10 @@ def toco_convert(input_data, tflite_input_type = lite_constants.INT32 elif input_tensor.dtype == _dtypes.int64: tflite_input_type = lite_constants.INT64 - # TODO(aselle): Insert strings when they are available + elif input_tensor.dtype == _dtypes.uint8: + tflite_input_type = lite_constants.QUANTIZED_UINT8 + elif input_tensor.dtype == _dtypes.string: + tflite_input_type = lite_constants.STRING else: raise ValueError("Tensors %s not known type %r" % (input_tensor.name, input_tensor.dtype)) @@ -178,10 +231,35 @@ def toco_convert(input_data, for output_tensor in output_tensors: model.output_arrays.append(tensor_name(output_tensor)) + return model, toco + + +def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs): + """Convert a model using TOCO. + + Typically this function is used to convert from TensorFlow GraphDef to TFLite.
+ Conversion can be customized by providing arguments that are forwarded to + `build_toco_convert_protos` (see documentation for details). + + Args: + input_data: Input data (i.e. often `sess.graph_def`), + input_tensors: List of input tensors. Type and shape are computed using + `foo.get_shape()` and `foo.dtype`. + output_tensors: List of output tensors (only .name is used from this). + *args: See `build_toco_convert_protos`, + **kwargs: See `build_toco_convert_protos`. + + Returns: + The converted data. For example if TFLite was the destination, then + this will be a tflite flatbuffer in a bytes array. - # TODO(aselle): Consider handling the case of allowing quantized - # inputs to be converted to float (via the toco.inference_input_type field). - data = toco_convert_protos(model.SerializeToString(), - toco.SerializeToString(), + Raises: + Defined in `build_toco_convert_protos`. + """ + model_flags, toco_flags = build_toco_convert_protos(input_tensors, + output_tensors, + *args, **kwargs) + data = toco_convert_protos(model_flags.SerializeToString(), + toco_flags.SerializeToString(), input_data.SerializeToString()) return data diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py index a7eddf3408f54dff5fa49ff6fa7b61cd0b8a22e4..1553464b9fe30f596c151bcc67efe891bb913ba3 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model.py +++ b/tensorflow/contrib/lite/python/convert_saved_model.py @@ -18,34 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.lite.python import convert -from tensorflow.contrib.lite.python import lite_constants -from tensorflow.contrib.lite.toco import model_flags_pb2 -from tensorflow.contrib.saved_model.python.saved_model import reader -from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils +from tensorflow.contrib.lite.python.convert import tensor_name 
from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session from tensorflow.python.framework import graph_util as tf_graph_util from tensorflow.python.framework import ops -from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import loader -from tensorflow.python.saved_model import signature_constants -from tensorflow.python.saved_model import tag_constants - - -def _write_and_flush_file(file_path, data_str): - """Writes data to file path. - - Args: - file_path: Full path of the file to store data in. - data_str: Data represented as a string. - - Returns: None. - """ - with gfile.Open(file_path, "wb") as data_file: - data_file.write(data_str) - data_file.flush() def _log_tensor_details(tensor_info): @@ -77,21 +57,8 @@ def _get_meta_graph_def(saved_model_dir, tag_set): Raises: ValueError: No valid MetaGraphDef for given tag_set. """ - saved_model = reader.read_saved_model(saved_model_dir) - tag_sets = [] - result_meta_graph_def = None - for meta_graph_def in saved_model.meta_graphs: - meta_graph_tag_set = set(meta_graph_def.meta_info_def.tags) - tag_sets.append(meta_graph_tag_set) - if meta_graph_tag_set == tag_set: - result_meta_graph_def = meta_graph_def - logging.info("The given saved_model contains the following tags: %s", - tag_sets) - if result_meta_graph_def is not None: - return result_meta_graph_def - else: - raise ValueError("No valid MetaGraphDef for this tag_set '{}'. Possible " - "values are '{}'. 
".format(tag_set, tag_sets)) + with session.Session(graph=ops.Graph()) as sess: + return loader.load(sess, tag_set, saved_model_dir) def _get_signature_def(meta_graph, signature_key): @@ -110,15 +77,13 @@ def _get_signature_def(meta_graph, signature_key): signature_def_map = meta_graph.signature_def signature_def_keys = set(signature_def_map.keys()) logging.info( - "The given saved_model MetaGraphDef contains SignatureDefs with the " + "The given SavedModel MetaGraphDef contains SignatureDefs with the " "following keys: %s", signature_def_keys) if signature_key not in signature_def_keys: - raise ValueError("No '{}' in the saved_model\'s SignatureDefs. Possible " - "values are '{}'. ".format(signature_key, - signature_def_keys)) - signature_def = signature_def_utils.get_signature_def_by_key( - meta_graph, signature_key) - return signature_def + raise ValueError("No '{}' in the SavedModel\'s SignatureDefs. Possible " + "values are '{}'.".format(signature_key, + ",".join(signature_def_keys))) + return signature_def_map[signature_key] def _get_inputs_outputs(signature_def): @@ -170,29 +135,10 @@ def _get_tensors(graph, signature_def_tensor_names=None, """ tensors = [] if user_tensor_names: - # Get the list of all of the tensors with and without the tensor index. - all_tensor_names = [ - tensor.name for op in graph.get_operations() for tensor in op.outputs - ] - all_tensor_names_only = [name.split(":")[0] for name in all_tensor_names] - # Sort the tensor names. user_tensor_names = sorted(user_tensor_names) - # Get the tensors associated with the tensor names. - tensors = [] - invalid_tensors = [] - for name in user_tensor_names: - if name not in all_tensor_names_only: - invalid_tensors.append(name) - else: - idx = all_tensor_names_only.index(name) - tensors.append(graph.get_tensor_by_name(all_tensor_names[idx])) - - # Throw ValueError if any user input names are not valid tensors. 
- if invalid_tensors: - raise ValueError("Invalid tensors '{}' were found.".format( - ",".join(invalid_tensors))) + tensors = get_tensors_from_tensor_names(graph, user_tensor_names) elif signature_def_tensor_names: tensors = [ graph.get_tensor_by_name(name) @@ -207,25 +153,74 @@ def _get_tensors(graph, signature_def_tensor_names=None, return tensors -def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes, - output_arrays, tag_set, signature_key, batch_size): +def get_tensors_from_tensor_names(graph, tensor_names): + """Gets the Tensors associated with the `tensor_names` in the provided graph. + + Args: + graph: TensorFlow Graph. + tensor_names: List of strings that represent names of tensors in the graph. + + Returns: + A list of Tensor objects in the same order the names are provided. + + Raises: + ValueError: + tensor_names contains an invalid tensor name. + """ + # Get the list of all of the tensors. + tensor_name_to_tensor = { + tensor_name(tensor): tensor for op in graph.get_operations() + for tensor in op.values() + } + + # Get the tensors associated with tensor_names. + tensors = [] + invalid_tensors = [] + for name in tensor_names: + tensor = tensor_name_to_tensor.get(name) + if tensor is None: + invalid_tensors.append(name) + else: + tensors.append(tensor) + + # Throw ValueError if any user input names are not valid tensors. + if invalid_tensors: + raise ValueError("Invalid tensors '{}' were found.".format( + ",".join(invalid_tensors))) + return tensors + + +def set_tensor_shapes(tensors, shapes): + """Sets Tensor shape for each tensor if the shape is defined. + + Args: + tensors: TensorFlow ops.Tensor. + shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). 
+ """ + if shapes: + for tensor in tensors: + shape = shapes.get(tensor_name(tensor)) + if shape is not None: + tensor.set_shape(shape) + + +def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, + output_arrays, tag_set, signature_key): """Converts a SavedModel to a frozen graph. Args: saved_model_dir: SavedModel directory to convert. input_arrays: List of input tensors to freeze graph with. Uses input arrays - from SignatureDef when none are provided. (default None) - input_shapes: Map of strings representing input tensor names to list of + from SignatureDef when none are provided. + input_shapes: Dict of strings representing input tensor names to list of integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). Automatically determined when input shapes is None (e.g., {"foo" : None}). - (default None) output_arrays: List of output tensors to freeze graph with. Uses output - arrays from SignatureDef when none are provided. (default None) + arrays from SignatureDef when none are provided. tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to - analyze. All tags in the tag set must be present. (default "serve") + analyze. All tags in the tag set must be present. signature_key: Key identifying SignatureDef containing inputs and outputs. - batch_size: Batch size for the model. Replaces the first dimension of an - input size array if undefined. (default 1) Returns: frozen_graph_def: Frozen GraphDef. @@ -236,186 +231,32 @@ def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes, ValueError: SavedModel doesn't contain a MetaGraphDef identified by tag_set. signature_key is not in the MetaGraphDef. + assets/ directory is in the MetaGraphDef. input_shapes does not match the length of input_arrays. - input_shapes has a None value after the 1st dimension. input_arrays or output_arrays are not valid. - Unable to load Session. """ - # Set default values for inputs if they are set to None. 
- if signature_key is None: - signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - if tag_set is None: - tag_set = set([tag_constants.SERVING]) - if batch_size is None: - batch_size = 1 - # Read SignatureDef. meta_graph = _get_meta_graph_def(saved_model_dir, tag_set) signature_def = _get_signature_def(meta_graph, signature_key) inputs, outputs = _get_inputs_outputs(signature_def) + # Check SavedModel for assets directory. + collection_def = meta_graph.collection_def + if constants.ASSETS_KEY in collection_def: + raise ValueError("SavedModels with assets/ directory are not supported.") + graph = ops.Graph() with session.Session(graph=graph) as sess: - # TODO(nupurgarg): Throw ValueError if SavedModel has assets/ directory. loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir) # Gets input and output tensors. # TODO(zhixianyan): Use TFLite supported Op list to filter outputs. in_tensors = _get_tensors(graph, inputs, input_arrays) out_tensors = _get_tensors(graph, outputs, output_arrays) - - # Gets fully defined tensor shape. An input tensor with None in the first - # dimension, e.g. (None, 224, 224, 3), is replaced with the batch_size. - # Shapes with None after the first dimension result in a ValueError. - # TODO(zhixianyan): Add supports for input tensor with more None in shape. - for tensor in in_tensors: - if (input_shapes and tensor.name in input_shapes and - input_shapes[tensor.name] is not None): - shape = input_shapes[tensor.name] - else: - shape = tensor.get_shape().as_list() - - if None in shape[1:]: - raise ValueError( - "None is only supported in the 1st dimension. 
Tensor '{0}' has " - "invalid shape '{1}'.".format(tensor.name, shape)) - elif shape[0] is None: - shape[0] = batch_size - tensor.set_shape(shape) + set_tensor_shapes(in_tensors, input_shapes) output_names = [node.split(":")[0] for node in outputs] frozen_graph_def = tf_graph_util.convert_variables_to_constants( sess, graph.as_graph_def(), output_names) return frozen_graph_def, in_tensors, out_tensors - raise ValueError("Unable to load Session.") - - -def saved_model_to_frozen_graphdef( - saved_model_dir, - output_file_model, - output_file_flags, - input_arrays=None, - input_shapes=None, - output_arrays=None, - tag_set=None, - signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - batch_size=1): - """Converts a SavedModel to a frozen graph. Writes graph to tmp directory. - - Stores frozen graph and command line flags in the tmp directory. - - Args: - saved_model_dir: SavedModel directory to convert. - output_file_model: Full file path to save frozen graph. - output_file_flags: Full file path to save ModelFlags. - input_arrays: List of input tensors to freeze graph with. Uses input arrays - from SignatureDef when none are provided. (default None) - input_shapes: Map of strings representing input tensor names to list of - integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). - Automatically determined when input shapes is None (e.g., {"foo" : None}). - (default None) - output_arrays: List of output tensors to freeze graph with. Uses output - arrays from SignatureDef when none are provided. (default None) - tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to - analyze. All tags in the tag set must be present. (default "serve") - signature_key: Key identifying SignatureDef containing inputs and outputs. - batch_size: Batch size for the model. Replaces the first dimension of an - input size array if undefined. (default 1) - - Returns: None. - - Raises: - ValueError: Unable to convert to frozen graph. 
- """ - frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model( - saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set, - signature_key, batch_size) - - # Initialize model flags. - model = model_flags_pb2.ModelFlags() - - for input_tensor in in_tensors: - input_array = model.input_arrays.add() - input_array.name = convert.tensor_name(input_tensor) - input_array.shape.dims.extend(map(int, input_tensor.get_shape())) - - for output_tensor in out_tensors: - model.output_arrays.append(convert.tensor_name(output_tensor)) - - # Write model and ModelFlags to file. ModelFlags contain input array and - # output array information that is parsed from the SignatureDef and used for - # analysis by TOCO. - _write_and_flush_file(output_file_model, frozen_graph_def.SerializeToString()) - _write_and_flush_file(output_file_flags, model.SerializeToString()) - - -def tflite_from_saved_model( - saved_model_dir, - output_file=None, - input_arrays=None, - input_shapes=None, - output_arrays=None, - tag_set=None, - signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - batch_size=1, - inference_type=lite_constants.FLOAT, - input_format=lite_constants.TENSORFLOW_GRAPHDEF, - output_format=lite_constants.TFLITE, - quantized_input_stats=None, - drop_control_dependency=True): - """Converts a SavedModel to TFLite FlatBuffer. - - Args: - saved_model_dir: SavedModel directory to convert. - output_file: File path to write result TFLite FlatBuffer. - input_arrays: List of input tensors to freeze graph with. Uses input arrays - from SignatureDef when none are provided. (default None) - input_shapes: Map of strings representing input tensor names to list of - integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). - Automatically determined when input shapes is None (e.g., {"foo" : None}). - (default None) - output_arrays: List of output tensors to freeze graph with. Uses output - arrays from SignatureDef when none are provided. 
(default None) - tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to - analyze. All tags in the tag set must be present. (default "serve") - signature_key: Key identifying SignatureDef containing inputs and outputs. - batch_size: Batch size for the model. Replaces the first dimension of an - input size array if undefined. (default 1) - inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`. - input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF). - output_format: Type of data to write (currently must be TFLITE or - GRAPHVIZ_DOT) - quantized_input_stats: For each member of input_tensors the mean and - std deviation of training data. Only needed if `inference_type` is - `QUANTIZED_UINT8`. - drop_control_dependency: Drops control dependencies silently. This is due - to tf lite not supporting control dependencies. - - Returns: - The converted data. For example if tflite was the destination, then - this will be a tflite flatbuffer in a bytes array. - - Raises: - ValueError: Unable to convert to frozen graph. 
- """ - frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model( - saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set, - signature_key, batch_size) - - result = convert.toco_convert( - input_data=frozen_graph_def, - input_tensors=in_tensors, - output_tensors=out_tensors, - inference_type=inference_type, - input_format=input_format, - output_format=output_format, - quantized_input_stats=quantized_input_stats, - drop_control_dependency=drop_control_dependency) - - if output_file is not None: - with gfile.Open(output_file, "wb") as f: - f.write(result) - logging.info("Successfully converted to: %s", output_file) - - return result diff --git a/tensorflow/contrib/lite/python/convert_saved_model_test.py b/tensorflow/contrib/lite/python/convert_saved_model_test.py index db95fc8ad7f94b52d33c72f6ec5819bdfe8cf05f..92c4ebb2465c2abaa1cefd020e69b2f7ad6a54a5 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model_test.py +++ b/tensorflow/contrib/lite/python/convert_saved_model_test.py @@ -25,12 +25,12 @@ from __future__ import print_function import os from tensorflow.contrib.lite.python import convert_saved_model -from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2 from tensorflow.python import keras from tensorflow.python.client import session from tensorflow.python.estimator import estimator_lib as estimator from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.layers import layers from tensorflow.python.ops import array_ops @@ -38,13 +38,68 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.losses import losses -from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.saved_model import saved_model 
+from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import training as train -class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase): +class TensorFunctionsTest(test_util.TensorFlowTestCase): + + def testGetTensorsValid(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + tensors = convert_saved_model.get_tensors_from_tensor_names( + sess.graph, ["Placeholder"]) + self.assertEqual("Placeholder:0", tensors[0].name) + + def testGetTensorsInvalid(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + with self.assertRaises(ValueError) as error: + convert_saved_model.get_tensors_from_tensor_names(sess.graph, + ["invalid-input"]) + self.assertEqual("Invalid tensors 'invalid-input' were found.", + str(error.exception)) + + def testSetTensorShapeValid(self): + tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [5, 3, 5]}) + self.assertEqual([5, 3, 5], tensor.shape.as_list()) + + def testSetTensorShapeNoneValid(self): + tensor = array_ops.placeholder(dtype=dtypes.float32) + self.assertEqual(None, tensor.shape) + + convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [1, 3, 5]}) + self.assertEqual([1, 3, 5], tensor.shape.as_list()) + + def testSetTensorShapeInvalid(self): + tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + convert_saved_model.set_tensor_shapes([tensor], + {"invalid-input": [5, 3, 5]}) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + def testSetTensorShapeEmpty(self): + tensor = 
array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + convert_saved_model.set_tensor_shapes([tensor], {}) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + +class FreezeSavedModelTest(test_util.TensorFlowTestCase): def _createSimpleSavedModel(self, shape): """Create a simple SavedModel on the fly.""" @@ -57,82 +112,167 @@ class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase): saved_model.simple_save(sess, saved_model_dir, inputs, outputs) return saved_model_dir + def _createSavedModelTwoInputArrays(self, shape): + """Create a simple SavedModel.""" + saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel") + with session.Session() as sess: + in_tensor_1 = array_ops.placeholder( + shape=shape, dtype=dtypes.float32, name="inputB") + in_tensor_2 = array_ops.placeholder( + shape=shape, dtype=dtypes.float32, name="inputA") + out_tensor = in_tensor_1 + in_tensor_2 + inputs = {"x": in_tensor_1, "y": in_tensor_2} + outputs = {"z": out_tensor} + saved_model.simple_save(sess, saved_model_dir, inputs, outputs) + return saved_model_dir + + def _getArrayNames(self, tensors): + return [tensor.name for tensor in tensors] + + def _getArrayShapes(self, tensors): + dims = [] + for tensor in tensors: + dim_tensor = [] + for dim in tensor.shape: + if isinstance(dim, tensor_shape.Dimension): + dim_tensor.append(dim.value) + else: + dim_tensor.append(dim) + dims.append(dim_tensor) + return dims + + def _convertSavedModel(self, + saved_model_dir, + input_arrays=None, + input_shapes=None, + output_arrays=None, + tag_set=None, + signature_key=None): + if tag_set is None: + tag_set = set([tag_constants.SERVING]) + if signature_key is None: + signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + graph_def, in_tensors, out_tensors = convert_saved_model.freeze_saved_model( + saved_model_dir=saved_model_dir, + input_arrays=input_arrays, + input_shapes=input_shapes, + 
output_arrays=output_arrays, + tag_set=tag_set, + signature_key=signature_key) + return graph_def, in_tensors, out_tensors + def testSimpleSavedModel(self): - """Test a simple SavedModel created on the fly.""" - # Create a simple SavedModel + """Test a SavedModel.""" saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - # Convert to tflite - result = convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir) - self.assertTrue(result) + _, in_tensors, out_tensors = self._convertSavedModel(saved_model_dir) + + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]]) def testSimpleSavedModelWithNoneBatchSizeInShape(self): - """Test a simple SavedModel, with None in input tensor's shape.""" + """Test a SavedModel with None in input tensor's shape.""" saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3]) - result = convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir) - self.assertTrue(result) + _, in_tensors, out_tensors = self._convertSavedModel(saved_model_dir) - def testSimpleSavedModelWithMoreNoneInShape(self): - """Test a simple SavedModel, fail as more None in input shape.""" - saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, None, 3]) - # Convert to tflite: this should raise ValueError, as 3rd dim is None. 
- with self.assertRaises(ValueError): - convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir) + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[None, 16, 16, 3]]) - def testSimpleSavedModelWithWrongSignatureKey(self): - """Test a simple SavedModel, fail as given signature is invalid.""" + def testSimpleSavedModelWithInvalidSignatureKey(self): + """Test a SavedModel that fails due to an invalid signature_key.""" saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - # Convert to tflite: this should raise ValueError, as - # signature_key does not exit in the saved_model. - with self.assertRaises(ValueError): - convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir, signature_key="wrong-key") - - def testSimpleSavedModelWithWrongOutputArray(self): - """Test a simple SavedModel, fail as given output_arrays is invalid.""" - # Create a simple SavedModel + with self.assertRaises(ValueError) as error: + self._convertSavedModel(saved_model_dir, signature_key="invalid-key") + self.assertEqual( + "No 'invalid-key' in the SavedModel's SignatureDefs. " + "Possible values are 'serving_default'.", str(error.exception)) + + def testSimpleSavedModelWithInvalidOutputArray(self): + """Test a SavedModel that fails due to invalid output arrays.""" saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - # Convert to tflite: this should raise ValueError, as - # output_arrays is not valid for the saved_model. 
- with self.assertRaises(ValueError): - convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir, output_arrays=["wrong-output"]) + with self.assertRaises(ValueError) as error: + self._convertSavedModel(saved_model_dir, output_arrays=["invalid-output"]) + self.assertEqual("Invalid tensors 'invalid-output' were found.", + str(error.exception)) def testSimpleSavedModelWithWrongInputArrays(self): - """Test a simple SavedModel, fail as given input_arrays is invalid.""" + """Test a SavedModel that fails due to invalid input arrays.""" saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - # Checks invalid input_arrays. - with self.assertRaises(ValueError): - convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir, input_arrays=["wrong-input"]) - # Checks valid and invalid input_arrays. - with self.assertRaises(ValueError): - convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir, - input_arrays=["Placeholder", "wrong-input"]) + + # Check invalid input_arrays. + with self.assertRaises(ValueError) as error: + self._convertSavedModel(saved_model_dir, input_arrays=["invalid-input"]) + self.assertEqual("Invalid tensors 'invalid-input' were found.", + str(error.exception)) + + # Check valid and invalid input_arrays. 
+ with self.assertRaises(ValueError) as error: + self._convertSavedModel( + saved_model_dir, input_arrays=["Placeholder", "invalid-input"]) + self.assertEqual("Invalid tensors 'invalid-input' were found.", + str(error.exception)) def testSimpleSavedModelWithCorrectArrays(self): - """Test a simple SavedModel, with correct input_arrays and output_arrays.""" + """Test a SavedModel with correct input_arrays and output_arrays.""" saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3]) - result = convert_saved_model.tflite_from_saved_model( + _, in_tensors, out_tensors = self._convertSavedModel( saved_model_dir=saved_model_dir, input_arrays=["Placeholder"], output_arrays=["add"]) - self.assertTrue(result) + + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[None, 16, 16, 3]]) def testSimpleSavedModelWithCorrectInputArrays(self): - """Test a simple SavedModel, with correct input_arrays and input_shapes.""" + """Test a SavedModel with correct input_arrays and input_shapes.""" saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - result = convert_saved_model.tflite_from_saved_model( + _, in_tensors, out_tensors = self._convertSavedModel( saved_model_dir=saved_model_dir, input_arrays=["Placeholder"], input_shapes={"Placeholder": [1, 16, 16, 3]}) - self.assertTrue(result) + + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]]) + + def testTwoInputArrays(self): + """Test a simple SavedModel.""" + saved_model_dir = self._createSavedModelTwoInputArrays(shape=[1, 16, 16, 3]) + + _, in_tensors, out_tensors = self._convertSavedModel( + saved_model_dir=saved_model_dir, input_arrays=["inputB", "inputA"]) + + self.assertEqual(self._getArrayNames(out_tensors), 
["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0", "inputB:0"]) + self.assertEqual( + self._getArrayShapes(in_tensors), [[1, 16, 16, 3], [1, 16, 16, 3]]) + + def testSubsetInputArrays(self): + """Test a SavedModel with a subset of the input array names of the model.""" + saved_model_dir = self._createSavedModelTwoInputArrays(shape=[1, 16, 16, 3]) + + # Check case where input shape is given. + _, in_tensors, out_tensors = self._convertSavedModel( + saved_model_dir=saved_model_dir, + input_arrays=["inputA"], + input_shapes={"inputA": [1, 16, 16, 3]}) + + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]]) + + # Check case where input shape is None. + _, in_tensors, out_tensors = self._convertSavedModel( + saved_model_dir=saved_model_dir, input_arrays=["inputA"]) + + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]]) def testMultipleMetaGraphDef(self): - """Test saved model with multiple MetaGraphDef.""" + """Test saved model with multiple MetaGraphDefs.""" saved_model_dir = os.path.join(self.get_temp_dir(), "savedmodel_two_mgd") builder = saved_model.builder.SavedModelBuilder(saved_model_dir) with session.Session(graph=ops.Graph()) as sess: @@ -161,91 +301,13 @@ class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase): builder.save(True) # Convert to tflite - convert_saved_model.tflite_from_saved_model( + _, in_tensors, out_tensors = self._convertSavedModel( saved_model_dir=saved_model_dir, tag_set=set([saved_model.tag_constants.SERVING, "additional_test_tag"])) - -class ConvertSavedModelTestBasicGraphToText(test_util.TensorFlowTestCase): - - def _createSimpleSavedModel(self, shape): - """Create a simple SavedModel.""" - saved_model_dir = 
os.path.join(self.get_temp_dir(), "simple_savedmodel") - with session.Session() as sess: - in_tensor_1 = array_ops.placeholder( - shape=shape, dtype=dtypes.float32, name="inputB") - in_tensor_2 = array_ops.placeholder( - shape=shape, dtype=dtypes.float32, name="inputA") - out_tensor = in_tensor_1 + in_tensor_2 - inputs = {"x": in_tensor_1, "y": in_tensor_2} - outputs = {"z": out_tensor} - saved_model.simple_save(sess, saved_model_dir, inputs, outputs) - return saved_model_dir - - def _getInputArrayNames(self, model_proto): - return [data.name for data in model_proto.input_arrays] - - def _getInputArrayShapes(self, model_proto): - return [ - [dim for dim in data.shape.dims] for data in model_proto.input_arrays - ] - - def _get_model_flags_proto_from_file(self, filename): - proto = _model_flags_pb2.ModelFlags() - with gfile.Open(filename, "rb") as output_file: - proto.ParseFromString(output_file.read()) - output_file.close() - return proto - - def testSimpleSavedModel(self): - """Test a simple SavedModel.""" - saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - output_file_model = os.path.join(self.get_temp_dir(), "model.pb") - output_file_flags = os.path.join(self.get_temp_dir(), "model.pbtxt") - - convert_saved_model.saved_model_to_frozen_graphdef( - saved_model_dir=saved_model_dir, - output_file_model=output_file_model, - output_file_flags=output_file_flags, - input_arrays=["inputB", "inputA"]) - - proto = self._get_model_flags_proto_from_file(output_file_flags) - self.assertEqual(proto.output_arrays, ["add"]) - self.assertEqual(self._getInputArrayNames(proto), ["inputA", "inputB"]) - self.assertEqual( - self._getInputArrayShapes(proto), [[1, 16, 16, 3], [1, 16, 16, 3]]) - - def testSimpleSavedModelWithDifferentInputNames(self): - """Test a simple SavedModel.""" - saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - output_file_model = os.path.join(self.get_temp_dir(), "model.pb") - output_file_flags = 
os.path.join(self.get_temp_dir(), "model.pbtxt") - - # Check case where input shape is given. - convert_saved_model.saved_model_to_frozen_graphdef( - saved_model_dir=saved_model_dir, - output_file_model=output_file_model, - output_file_flags=output_file_flags, - input_arrays=["inputA"], - input_shapes={"inputA": [1, 16, 16, 3]}) - - proto = self._get_model_flags_proto_from_file(output_file_flags) - self.assertEqual(proto.output_arrays, ["add"]) - self.assertEqual(self._getInputArrayNames(proto), ["inputA"]) - self.assertEqual(self._getInputArrayShapes(proto), [[1, 16, 16, 3]]) - - # Check case where input shape is None. - convert_saved_model.saved_model_to_frozen_graphdef( - saved_model_dir=saved_model_dir, - output_file_model=output_file_model, - output_file_flags=output_file_flags, - input_arrays=["inputA"], - input_shapes={"inputA": None}) - - proto = self._get_model_flags_proto_from_file(output_file_flags) - self.assertEqual(proto.output_arrays, ["add"]) - self.assertEqual(self._getInputArrayNames(proto), ["inputA"]) - self.assertEqual(self._getInputArrayShapes(proto), [[1, 16, 16, 3]]) + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[1, 28, 28]]) class Model(keras.Model): @@ -354,7 +416,7 @@ def dummy_input_fn(): return image, labels -class ConvertSavedModelTestTrainGraph(test_util.TensorFlowTestCase): +class FreezeSavedModelTestTrainGraph(test_util.TensorFlowTestCase): def testTrainedMnistSavedModel(self): """Test mnist SavedModel, trained with dummy data and small steps.""" @@ -379,13 +441,16 @@ class ConvertSavedModelTestTrainGraph(test_util.TensorFlowTestCase): # Convert to tflite and test output saved_model_name = os.listdir(saved_model_dir)[0] saved_model_final_dir = os.path.join(saved_model_dir, saved_model_name) - output_file = os.path.join(saved_model_dir, saved_model_final_dir + ".lite") + # TODO(zhixianyan): no 
need to limit output_arrays to `Softmax' # once b/74205001 fixed and argmax implemented in tflite. - result = convert_saved_model.tflite_from_saved_model( + result = convert_saved_model.freeze_saved_model( saved_model_dir=saved_model_final_dir, + input_arrays=None, + input_shapes=None, output_arrays=["Softmax"], - output_file=output_file) + tag_set=set([tag_constants.SERVING]), + signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) self.assertTrue(result) diff --git a/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py b/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py deleted file mode 100644 index 4d9782f4a6a9e853c3afdbd97d4264a818937e63..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Python console command for generating frozen models from SavedModels. - -This exists to add SavedModel compatibility to TOCO. 
-""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import sys -from tensorflow.contrib.lite.python.convert_saved_model import saved_model_to_frozen_graphdef -from tensorflow.python.platform import app - -FLAGS = None - - -def execute(unused_args): - """Calls function to convert the SavedModel to a frozen graph.""" - # Error handling. - if FLAGS.input_shapes and not FLAGS.input_arrays: - raise ValueError("Input shapes requires input arrays to be specified.") - - # Calls saved_model_to_frozen_graphdef function to generate frozen graph. - input_arrays = (FLAGS.input_arrays.split(",") if FLAGS.input_arrays else None) - input_shapes = None - if FLAGS.input_shapes: - input_shapes = { - input_arrays[idx]: shape.split(",") - for idx, shape in enumerate(FLAGS.input_shapes.split(":")) - } - output_arrays = ( - FLAGS.output_arrays.split(",") if FLAGS.output_arrays else None) - tag_set = set(FLAGS.tag_set.split(",")) if FLAGS.tag_set else None - - saved_model_to_frozen_graphdef( - saved_model_dir=FLAGS.saved_model_directory, - output_file_model=FLAGS.output_file_model, - output_file_flags=FLAGS.output_file_flags, - input_arrays=input_arrays, - input_shapes=input_shapes, - output_arrays=output_arrays, - tag_set=tag_set, - signature_key=FLAGS.signature_key, - batch_size=FLAGS.batch_size) - - -def main(): - global FLAGS - # Parses flags. 
- parser = argparse.ArgumentParser( - description="Invoke SavedModel to frozen model converter.") - parser.add_argument( - "saved_model_directory", - type=str, - help="Full path to directory containing the SavedModel.") - parser.add_argument( - "output_file_model", - type=str, - help="Full file path to save frozen graph.") - parser.add_argument( - "output_file_flags", type=str, help="Full file path to save ModelFlags.") - parser.add_argument( - "--input_arrays", - type=str, - help="Name of the input arrays, comma-separated.") - parser.add_argument( - "--input_shapes", - type=str, - help="Shapes corresponding to --input_arrays, colon-separated.") - parser.add_argument( - "--output_arrays", - type=str, - help="Name of the output arrays, comma-separated.") - parser.add_argument( - "--tag_set", type=str, help="Name of output arrays, comma-separated.") - parser.add_argument( - "--signature_key", - type=str, - help="Key identifying SignatureDef containing inputs and outputs.") - parser.add_argument( - "--batch_size", - type=int, - help="Batch size for the model. 
Replaces the first dimension of an " - "input size array if undefined.") - - FLAGS, unparsed = parser.parse_known_args() - - app.run(main=execute, argv=[sys.argv[0]] + unparsed) - - -if __name__ == "__main__": - main() diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py index 5fbc55145217dd8a4e3eec4108e18d1c2be5c883..779bda4c9d05fd056d6a262412fdcf0d47e7c57c 100644 --- a/tensorflow/contrib/lite/python/interpreter.py +++ b/tensorflow/contrib/lite/python/interpreter.py @@ -54,7 +54,7 @@ class Interpreter(object): elif model_content and not model_path: self._interpreter = ( _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer( - model_content, len(model_content))) + model_content)) if not self._interpreter: raise ValueError( 'Failed to create model from {} bytes'.format(len(model_content))) diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD index 453eda6e7345762666917fd501b69c7181c349e8..12ab38847dc0f838ae2c6bf80ed80805285e4b8b 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD @@ -15,7 +15,7 @@ cc_library( "//tensorflow/contrib/lite/kernels:builtin_ops", "//tensorflow/core:lib", "//tensorflow/python:numpy_lib", - "//util/python:python_headers", + "//third_party/python_runtime:headers", "@com_google_absl//absl/memory", ], ) @@ -27,6 +27,6 @@ tf_py_wrap_cc( ], deps = [ ":interpreter_wrapper_lib", - "//util/python:python_headers", + "//third_party/python_runtime:headers", ], ) diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 16f4f30b94313453c3a4f0496a6ac5649847f076..5f304ad45d400b13e20bda8184b5b40cfe13f6c2 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ 
b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -42,6 +42,8 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter( return nullptr; } + tensorflow::ImportNumpy(); + std::unique_ptr<tflite::Interpreter> interpreter; tflite::InterpreterBuilder(*model, resolver)(&interpreter); if (interpreter) { @@ -331,9 +333,14 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( } InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( - const char* data, size_t len) { + PyObject* data) { + char * buf = nullptr; + Py_ssize_t length; + if (PY_TO_CPPSTRING(data, &buf, &length) == -1) { + return nullptr; + } std::unique_ptr<tflite::FlatBufferModel> model = - tflite::FlatBufferModel::BuildFromBuffer(data, len); + tflite::FlatBufferModel::BuildFromBuffer(buf, length); return model ? new InterpreterWrapper(std::move(model)) : nullptr; } diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h index 0972c572595f5044a305a81afaccbea5f131247c..c02aa3804367f787016ef78fc8557005507f051b 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -19,6 +19,8 @@ limitations under the License. #include <string> #include <vector> +// Place `<locale>` before <Python.h> to avoid build failures in macOS. +#include <locale> #include <Python.h> // We forward declare TFLite classes here to avoid exposing them to SWIG. @@ -40,8 +42,7 @@ class InterpreterWrapper { static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path); // SWIG caller takes ownership of pointer.
- static InterpreterWrapper* CreateWrapperCPPFromBuffer(const char* data, - size_t len); + static InterpreterWrapper* CreateWrapperCPPFromBuffer(PyObject* data); ~InterpreterWrapper(); bool AllocateTensors(); diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 86b25e68acaf5d74e3dd11784446e7bda3d329ee..1e76157d2f947e43000085e99edeabab54a1ebd8 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -16,23 +16,377 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice. +@@TocoConverter @@toco_convert @@toco_convert_protos -@@tflite_from_saved_model @@Interpreter @@OpHint @@convert_op_hints_to_stubs +@@FLOAT +@@QUANTIZED_UINT8 +@@STRING +@@TFLITE +@@GRAPHVIZ_DOT + """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=unused-import +from six import PY3 + +from google.protobuf import text_format as _text_format +from google.protobuf.message import DecodeError +from tensorflow.contrib.lite.python import lite_constants as constants +from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import +from tensorflow.contrib.lite.python.convert import tensor_name from tensorflow.contrib.lite.python.convert import toco_convert -from tensorflow.contrib.lite.python.convert import toco_convert_protos -from tensorflow.contrib.lite.python.convert_saved_model import tflite_from_saved_model -from tensorflow.contrib.lite.python.interpreter import Interpreter -from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs -from tensorflow.contrib.lite.python.op_hint import OpHint -# pylint: enable=unused-import +from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import +from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model +from 
tensorflow.contrib.lite.python.convert_saved_model import get_tensors_from_tensor_names +from tensorflow.contrib.lite.python.convert_saved_model import set_tensor_shapes +from tensorflow.contrib.lite.python.interpreter import Interpreter # pylint: disable=unused-import +from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import +from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: disable=unused-import +from tensorflow.core.framework import graph_pb2 as _graph_pb2 +from tensorflow.python.client import session as _session +from tensorflow.python.framework import graph_util as tf_graph_util +from tensorflow.python.framework.importer import import_graph_def +from tensorflow.python.ops.variables import global_variables_initializer +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import tag_constants + + +class TocoConverter(object): + """Convert a TensorFlow model into `output_format` using TOCO. + + This is used to convert from a TensorFlow GraphDef or SavedModel into either a + TFLite FlatBuffer or graph visualization. + + Attributes: + + inference_type: Target data type of arrays in the output file. Currently + must be `{FLOAT, QUANTIZED_UINT8, STRING}`. (default FLOAT) + inference_input_type: Target data type of input arrays. Allows for a + different type for input arrays in the case of quantization. Currently + must be `{FLOAT, QUANTIZED_UINT8, STRING}`. (default `inference_type`) + output_format: Output file format. Currently must be `{TFLITE, + GRAPHVIZ_DOT}`. (default TFLITE) + quantized_input_stats: Dict of strings representing input tensor names + mapped to tuple of floats representing the mean and standard deviation + of the training data (e.g., {"foo" : (0., 1.)}). Only needed if + `inference_type` is `QUANTIZED_UINT8`.
(default {}) + default_ranges_stats: Tuple of integers representing (min, max) range values + for all arrays without a specified range. Intended for experimenting with + quantization via "dummy quantization". (default None) + drop_control_dependency: Boolean indicating whether to drop control + dependencies silently. This is due to TFLite not supporting control + dependencies. (default True) + reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant + nodes in unexpected locations. Used when the location of the FakeQuant + nodes is preventing graph transformations necessary to convert the graph. + Results in a graph that differs from the quantized training graph, + potentially causing differing arithmetic behavior. (default False) + change_concat_input_ranges: Boolean to change behavior of min/max ranges for + inputs and outputs of the concat operator for quantized models. Changes + the ranges of concat operator overlap when true. (default False) + allow_custom_ops: Boolean indicating whether to allow custom operations. + When false any unknown operation is an error. When true, custom ops are + created for any op that is unknown. The developer will need to provide + these to the TensorFlow Lite runtime with a custom resolver. + (default False) + quantize_weights: Boolean indicating whether to store weights as quantized + weights followed by dequantize operations. Computation is still done in + float, but reduces model size (at the cost of accuracy and latency). + (default False) + dump_graphviz_dir: Full filepath of folder to dump the graphs at various + stages of processing GraphViz .dot files. Preferred over + --output_format=GRAPHVIZ_DOT in order to keep the requirements of the + output file. (default None) + dump_graphviz_video: Boolean indicating whether to dump the graph after + every graph transformation. (default False) + + Example usage: + + # Converting a GraphDef from session. 
+ converter = lite.TocoConverter.from_session(sess, in_tensors, out_tensors) + tflite_model = converter.convert() + open("converted_model.tflite", "wb").write(tflite_model) + + # Converting a GraphDef from file. + converter = lite.TocoConverter.from_frozen_graph( + graph_def_file, input_arrays, output_arrays) + tflite_model = converter.convert() + open("converted_model.tflite", "wb").write(tflite_model) + + # Converting a SavedModel. + converter = lite.TocoConverter.from_saved_model(saved_model_dir) + tflite_model = converter.convert() + """ + + def __init__(self, graph_def, input_tensors, output_tensors): + """Constructor for TocoConverter. + + Args: + + graph_def: TensorFlow GraphDef. + input_tensors: List of input tensors. Type and shape are computed using + `foo.get_shape()` and `foo.dtype`. + output_tensors: List of output tensors (only .name is used from this). + """ + self._graph_def = graph_def + self._input_tensors = input_tensors + self._output_tensors = output_tensors + self.inference_type = constants.FLOAT + self.inference_input_type = None + self.output_format = constants.TFLITE + self.quantized_input_stats = {} + self.default_ranges_stats = None + self.drop_control_dependency = True + self.reorder_across_fake_quant = False + self.change_concat_input_ranges = False + self.allow_custom_ops = False + self.quantize_weights = False + self.dump_graphviz_dir = None + self.dump_graphviz_video = False + + @classmethod + def from_session(cls, sess, input_tensors, output_tensors): + """Creates a TocoConverter class from a TensorFlow Session. + + Args: + sess: TensorFlow Session. + input_tensors: List of input tensors. Type and shape are computed using + `foo.get_shape()` and `foo.dtype`. + output_tensors: List of output tensors (only .name is used from this). + + Returns: + TocoConverter class. 
+ """ + graph_def = _freeze_graph(sess, output_tensors) + return cls(graph_def, input_tensors, output_tensors) + + @classmethod + def from_frozen_graph(cls, + graph_def_file, + input_arrays, + output_arrays, + input_shapes=None): + """Creates a TocoConverter class from a file containing a frozen GraphDef. + + Args: + graph_def_file: Full filepath of file containing TensorFlow GraphDef. + input_arrays: List of input tensors to freeze graph with. + output_arrays: List of output tensors to freeze graph with. + input_shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). + Automatically determined when input shapes is None (e.g., {"foo" : + None}). (default None) + + Returns: + TocoConverter class. + + Raises: + ValueError: + Unable to parse input file. + The graph is not frozen. + input_arrays or output_arrays contains an invalid tensor name. + """ + with _session.Session() as sess: + sess.run(global_variables_initializer()) + + # Read GraphDef from file. + graph_def = _graph_pb2.GraphDef() + with open(graph_def_file, "rb") as f: + file_content = f.read() + try: + graph_def.ParseFromString(file_content) + except (_text_format.ParseError, DecodeError): + try: + print("Ignore 'tcmalloc: large alloc' warnings.") + + if not isinstance(file_content, str): + if PY3: + file_content = file_content.decode('utf-8') + else: + file_content = file_content.encode('utf-8') + _text_format.Merge(file_content, graph_def) + except (_text_format.ParseError, DecodeError): + raise ValueError( + "Unable to parse input file '{}'.".format(graph_def_file)) + sess.graph.as_default() + import_graph_def(graph_def, name="") + + # Get input and output tensors. + input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays) + output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays) + set_tensor_shapes(input_tensors, input_shapes) + + # Check if graph is frozen. 
+ if not _is_frozen_graph(sess): + raise ValueError("Please freeze the graph using freeze_graph.py.") + + # Create TocoConverter class. + return cls(sess.graph_def, input_tensors, output_tensors) + + @classmethod + def from_saved_model(cls, + saved_model_dir, + input_arrays=None, + input_shapes=None, + output_arrays=None, + tag_set=None, + signature_key=None): + """Creates a TocoConverter class from a SavedModel. + + Args: + saved_model_dir: SavedModel directory to convert. + input_arrays: List of input tensors to freeze graph with. Uses input + arrays from SignatureDef when none are provided. (default None) + input_shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). + Automatically determined when input shapes is None (e.g., {"foo" : + None}). (default None) + output_arrays: List of output tensors to freeze graph with. Uses output + arrays from SignatureDef when none are provided. (default None) + tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to + analyze. All tags in the tag set must be present. (default set("serve")) + signature_key: Key identifying SignatureDef containing inputs and outputs. + (default DEFAULT_SERVING_SIGNATURE_DEF_KEY) + + Returns: + TocoConverter class. + """ + if tag_set is None: + tag_set = set([tag_constants.SERVING]) + if signature_key is None: + signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + + result = freeze_saved_model(saved_model_dir, input_arrays, input_shapes, + output_arrays, tag_set, signature_key) + return cls( + graph_def=result[0], input_tensors=result[1], output_tensors=result[2]) + + def convert(self): + """Converts a TensorFlow GraphDef based on instance variables. + + Returns: + The converted data in serialized format. Either a TFLite Flatbuffer or a + Graphviz graph depending on value in `output_format`. + + Raises: + ValueError: + Input shape is not specified. 
+ None value for dimension in input_tensor. + """ + # Checks dimensions in input tensor. + for tensor in self._input_tensors: + if not tensor.get_shape(): + raise ValueError("Provide an input shape for input array '{0}'.".format( + tensor_name(tensor))) + shape = tensor.get_shape().as_list() + if None in shape[1:]: + raise ValueError( + "None is only supported in the 1st dimension. Tensor '{0}' has " + "invalid shape '{1}'.".format(tensor_name(tensor), shape)) + elif shape[0] is None: + self._set_batch_size(batch_size=1) + + # Get quantization stats. Ensures there is one stat per name if the stats + # are specified. + if self.quantized_input_stats: + quantized_stats = [] + invalid_stats = [] + for tensor in self._input_tensors: + name = tensor_name(tensor) + if name in self.quantized_input_stats: + quantized_stats.append(self.quantized_input_stats[name]) + else: + invalid_stats.append(name) + + if invalid_stats: + raise ValueError("Quantization input stats are not available for input " + "tensors '{0}'.".format(",".join(invalid_stats))) + else: + quantized_stats = None + + # Converts model. + result = toco_convert( + input_data=self._graph_def, + input_tensors=self._input_tensors, + output_tensors=self._output_tensors, + inference_type=self.inference_type, + inference_input_type=self.inference_input_type, + input_format=constants.TENSORFLOW_GRAPHDEF, + output_format=self.output_format, + quantized_input_stats=quantized_stats, + default_ranges_stats=self.default_ranges_stats, + drop_control_dependency=self.drop_control_dependency, + reorder_across_fake_quant=self.reorder_across_fake_quant, + change_concat_input_ranges=self.change_concat_input_ranges, + allow_custom_ops=self.allow_custom_ops, + quantize_weights=self.quantize_weights, + dump_graphviz_dir=self.dump_graphviz_dir, + dump_graphviz_video=self.dump_graphviz_video) + return result + + def get_input_arrays(self): + """Returns a list of the names of the input tensors. + + Returns: + List of strings. 
+ """ + return [tensor_name(tensor) for tensor in self._input_tensors] + + def _set_batch_size(self, batch_size): + """Sets the first dimension of the input tensor to `batch_size`. + + Args: + batch_size: Batch size for the model. Replaces the first dimension of an + input size array if undefined. (default 1) + """ + for tensor in self._input_tensors: + shape = tensor.get_shape().as_list() + shape[0] = batch_size + tensor.set_shape(shape) + + +def _is_frozen_graph(sess): + """Determines if the graph is frozen. + + Determines if a graph has previously been frozen by checking for any + operations of type Variable*. If variables are found, the graph is not frozen. + + Args: + sess: TensorFlow Session. + + Returns: + Bool. + """ + for op in sess.graph.get_operations(): + if op.type.startswith("Variable"): + return False + return True + + +def _freeze_graph(sess, output_tensors): + """Returns a frozen GraphDef. + + Freezes a graph with Variables in it. Otherwise the existing GraphDef is + returned. + + Args: + sess: TensorFlow Session. + output_tensors: List of output tensors (only .name is used from this). + + Returns: + Frozen GraphDef. + """ + if not _is_frozen_graph(sess): + sess.run(global_variables_initializer()) + output_arrays = [tensor_name(tensor) for tensor in output_tensors] + return tf_graph_util.convert_variables_to_constants(sess, sess.graph_def, + output_arrays) + else: + return sess.graph_def diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8c9d2c1651dd2d0b3cd27cf638c04429e3131efb --- /dev/null +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -0,0 +1,622 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for lite.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np + +from tensorflow.contrib.lite.python import lite +from tensorflow.contrib.lite.python import lite_constants +from tensorflow.contrib.lite.python.interpreter import Interpreter +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.saved_model import saved_model +from tensorflow.python.training.training_util import write_graph + + +class FromSessionTest(test_util.TensorFlowTestCase): + + def testFloat(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. 
+ interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testQuantization(self): + in_tensor_1 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') + in_tensor_2 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') + out_tensor = array_ops.fake_quant_with_min_max_args( + in_tensor_1 + in_tensor_2, min=0., max=1., name='output') + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session( + sess, [in_tensor_1, in_tensor_2], [out_tensor]) + converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.quantized_input_stats = { + 'inputA': (0., 1.), + 'inputB': (0., 1.) + } # mean, std_dev + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. 
+ interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(2, len(input_details)) + self.assertEqual('inputA', input_details[0]['name']) + self.assertEqual(np.uint8, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((1., 0.), + input_details[0]['quantization']) # scale, zero_point + + self.assertEqual('inputB', input_details[1]['name']) + self.assertEqual(np.uint8, input_details[1]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) + self.assertEqual((1., 0.), + input_details[1]['quantization']) # scale, zero_point + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('output', output_details[0]['name']) + self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertTrue(output_details[0]['quantization'][0] > 0) # scale + + def testQuantizationInvalid(self): + in_tensor_1 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') + in_tensor_2 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') + out_tensor = array_ops.fake_quant_with_min_max_args( + in_tensor_1 + in_tensor_2, min=0., max=1., name='output') + sess = session.Session() + + # Convert model and ensure model is not None. 
+ converter = lite.TocoConverter.from_session( + sess, [in_tensor_1, in_tensor_2], [out_tensor]) + converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.quantized_input_stats = {'inputA': (0., 1.)} # mean, std_dev + with self.assertRaises(ValueError) as error: + converter.convert() + self.assertEqual( + 'Quantization input stats are not available for input tensors ' + '\'inputB\'.', str(error.exception)) + + def testSizeNoneInvalid(self): + in_tensor = array_ops.placeholder(dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Test invalid shape. The input shape is unspecified. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + with self.assertRaises(ValueError) as error: + converter.convert() + self.assertEqual('Provide an input shape for input array \'Placeholder\'.', + str(error.exception)) + + def testBatchSizeInvalid(self): + in_tensor = array_ops.placeholder( + shape=[1, None, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Test invalid shape. None after 1st dimension. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + with self.assertRaises(ValueError) as error: + converter.convert() + self.assertEqual( + 'None is only supported in the 1st dimension. Tensor ' + '\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.', + str(error.exception)) + + def testBatchSizeValid(self): + in_tensor = array_ops.placeholder( + shape=[None, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testFreezeGraph(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + var = variable_scope.get_variable( + 'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + var + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. 
+ interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + # TODO(nupurgarg): Verify value of contents in GraphViz. + def testGraphviz(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + converter.output_format = lite_constants.GRAPHVIZ_DOT + graphviz_output = converter.convert() + self.assertTrue(graphviz_output) + + # TODO(nupurgarg): Verify value of contents in GraphViz. + def testDumpGraphviz(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + graphviz_dir = self.get_temp_dir() + converter.dump_graphviz_dir = graphviz_dir + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Ensure interpreter is able to allocate and check graphviz data. 
+ interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + num_items_graphviz = len(os.listdir(graphviz_dir)) + self.assertTrue(num_items_graphviz) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + graphviz_dir = self.get_temp_dir() + converter.dump_graphviz_dir = graphviz_dir + converter.dump_graphviz_video = True + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Ensure graphviz folder has more data after using video flag. + num_items_graphviz_video = len(os.listdir(graphviz_dir)) + self.assertTrue(num_items_graphviz_video > num_items_graphviz) + + def testInferenceInputType(self): + in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.uint8) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + converter.inference_input_type = lite_constants.QUANTIZED_UINT8 + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. 
+ interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.uint8, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testDefaultRangesStats(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev + converter.default_ranges_stats = (0, 6) # min, max + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.uint8, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((1., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertTrue(output_details[0]['quantization'][0] > 0) # scale + + def testQuantizeWeights(self): + np.random.seed(0) + # We need the tensor to have more than 1024 elements for quantize_weights + # to kick in. Thus, the [33, 33] shape. + in_tensor_1 = array_ops.placeholder( + shape=[33, 33], dtype=dtypes.float32, name='inputA') + in_tensor_2 = constant_op.constant( + np.random.uniform(low=-10., high=10., size=(33, 33)), + shape=[33, 33], + dtype=dtypes.float32, + name='inputB') + out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output') + sess = session.Session() + + # Convert float model. + float_converter = lite.TocoConverter.from_session(sess, [in_tensor_1], + [out_tensor]) + float_tflite = float_converter.convert() + self.assertTrue(float_tflite) + + # Convert quantized weights model. + quantized_weights_converter = lite.TocoConverter.from_session( + sess, [in_tensor_1], [out_tensor]) + quantized_weights_converter.quantize_weights = True + quantized_weights_tflite = quantized_weights_converter.convert() + self.assertTrue(quantized_weights_tflite) + + # Ensure that the quantized weights tflite model is smaller. 
+ self.assertTrue(len(quantized_weights_tflite) < len(float_tflite)) + + +class FromFrozenGraphFile(test_util.TensorFlowTestCase): + + def testFloat(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') + write_graph(sess.graph_def, '', graph_def_file, False) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_frozen_graph(graph_def_file, + ['Placeholder'], ['add']) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testFloatWithShapesArray(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') + write_graph(sess.graph_def, '', graph_def_file, False) + + # Convert model and ensure model is not None. 
+ converter = lite.TocoConverter.from_frozen_graph( + graph_def_file, ['Placeholder'], ['add'], + input_shapes={'Placeholder': [1, 16, 16, 3]}) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + + def testFreezeGraph(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + var = variable_scope.get_variable( + 'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + var + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') + write_graph(sess.graph_def, '', graph_def_file, False) + + # Ensure the graph with variables cannot be converted. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'], + ['add']) + self.assertEqual('Please freeze the graph using freeze_graph.py.', + str(error.exception)) + + def testPbtxt(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt') + write_graph(sess.graph_def, '', graph_def_file, True) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_frozen_graph(graph_def_file, + ['Placeholder'], ['add']) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. 
+ interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testInvalidFile(self): + graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file') + with gfile.Open(graph_def_file, 'wb') as temp_file: + temp_file.write('bad data') + temp_file.flush() + + # Attempts to convert the invalid model. 
+ with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'], + ['add']) + self.assertEqual( + 'Unable to parse input file \'{}\'.'.format(graph_def_file), + str(error.exception)) + + +class FromSavedModelTest(test_util.TensorFlowTestCase): + + def _createSavedModel(self, shape): + """Create a simple SavedModel.""" + saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel') + with session.Session() as sess: + in_tensor_1 = array_ops.placeholder( + shape=shape, dtype=dtypes.float32, name='inputB') + in_tensor_2 = array_ops.placeholder( + shape=shape, dtype=dtypes.float32, name='inputA') + out_tensor = in_tensor_1 + in_tensor_2 + inputs = {'x': in_tensor_1, 'y': in_tensor_2} + outputs = {'z': out_tensor} + saved_model.simple_save(sess, saved_model_dir, inputs, outputs) + return saved_model_dir + + def testSimpleModel(self): + """Test a SavedModel.""" + saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3]) + + # Convert model and ensure model is not None. 
+ converter = lite.TocoConverter.from_saved_model(saved_model_dir) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(2, len(input_details)) + self.assertEqual('inputA', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + self.assertEqual('inputB', input_details[1]['name']) + self.assertEqual(np.float32, input_details[1]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) + self.assertEqual((0., 0.), input_details[1]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testNoneBatchSize(self): + """Test a SavedModel, with None in input tensor's shape.""" + saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3]) + + converter = lite.TocoConverter.from_saved_model(saved_model_dir) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. 
+ interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(2, len(input_details)) + self.assertEqual('inputA', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + self.assertEqual('inputB', input_details[1]['name']) + self.assertEqual(np.float32, input_details[1]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) + self.assertEqual((0., 0.), input_details[1]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testOrderInputArrays(self): + """Test a SavedModel ordering of input arrays.""" + saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3]) + + converter = lite.TocoConverter.from_saved_model( + saved_model_dir, input_arrays=['inputB', 'inputA']) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. 
+ interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(2, len(input_details)) + self.assertEqual('inputA', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + self.assertEqual('inputB', input_details[1]['name']) + self.assertEqual(np.float32, input_details[1]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) + self.assertEqual((0., 0.), input_details[1]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testSubsetInputArrays(self): + """Test a SavedModel with a subset of the input array names of the model.""" + saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3]) + + # Check case where input shape is given. + converter = lite.TocoConverter.from_saved_model( + saved_model_dir, + input_arrays=['inputA'], + input_shapes={'inputA': [1, 16, 16, 3]}) + + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check case where input shape is None. 
+ converter = lite.TocoConverter.from_saved_model( + saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None}) + + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py new file mode 100644 index 0000000000000000000000000000000000000000..7bbfe2a6015925555a665ed6eabce9d7084454ca --- /dev/null +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -0,0 +1,364 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Python command line interface for running TOCO.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import os +import sys + +from tensorflow.contrib.lite.python import lite +from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 +from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2 +from tensorflow.python.platform import app + + +def _parse_array(values): + if values: + return values.split(",") + + +def _parse_int_array(values): + if values: + return [int(val) for val in values.split(",")] + + +def _parse_set(values): + if values: + return set(values.split(",")) + + +def _get_toco_converter(flags): + """Makes a TocoConverter object based on the flags provided. + + Args: + flags: argparse.Namespace object containing TFLite flags. + + Returns: + TocoConverter object. + """ + # Parse input and output arrays. + input_arrays = _parse_array(flags.input_arrays) + input_shapes = None + if flags.input_shapes: + input_shapes_list = [ + _parse_int_array(shape) for shape in flags.input_shapes.split(":") + ] + input_shapes = dict(zip(input_arrays, input_shapes_list)) + output_arrays = _parse_array(flags.output_arrays) + + converter_kwargs = { + "input_arrays": input_arrays, + "input_shapes": input_shapes, + "output_arrays": output_arrays + } + + # Create TocoConverter. 
+ if flags.graph_def_file: + converter_fn = lite.TocoConverter.from_frozen_graph + converter_kwargs["graph_def_file"] = flags.graph_def_file + elif flags.saved_model_dir: + converter_fn = lite.TocoConverter.from_saved_model + converter_kwargs["saved_model_dir"] = flags.saved_model_dir + converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set) + converter_kwargs["signature_key"] = flags.saved_model_signature_key + + return converter_fn(**converter_kwargs) + + +def _convert_model(flags): + """Calls function to convert the TensorFlow model into a TFLite model. + + Args: + flags: argparse.Namespace object. + + Raises: + ValueError: Invalid flags. + """ + # Create converter. + converter = _get_toco_converter(flags) + if flags.inference_type: + converter.inference_type = _types_pb2.IODataType.Value(flags.inference_type) + if flags.inference_input_type: + converter.inference_input_type = _types_pb2.IODataType.Value( + flags.inference_input_type) + if flags.output_format: + converter.output_format = _toco_flags_pb2.FileFormat.Value( + flags.output_format) + + if flags.mean_values and flags.std_dev_values: + input_arrays = converter.get_input_arrays() + std_dev_values = _parse_int_array(flags.std_dev_values) + mean_values = _parse_int_array(flags.mean_values) + quant_stats = list(zip(mean_values, std_dev_values)) + if ((not flags.input_arrays and len(input_arrays) > 1) or + (len(input_arrays) != len(quant_stats))): + raise ValueError("Mismatching --input_arrays, --std_dev_values, and " + "--mean_values. The flags must have the same number of " + "items. The current input arrays are '{0}'.
" + "--input_arrays must be present when specifying " + "--std_dev_values and --mean_values with multiple input " + "tensors in order to map between names and " + "values.".format(",".join(input_arrays))) + converter.quantized_input_stats = dict(zip(input_arrays, quant_stats)) + if (flags.default_ranges_min is not None) and (flags.default_ranges_max is + not None): + converter.default_ranges_stats = (flags.default_ranges_min, + flags.default_ranges_max) + + if flags.drop_control_dependency: + converter.drop_control_dependency = flags.drop_control_dependency + if flags.reorder_across_fake_quant: + converter.reorder_across_fake_quant = flags.reorder_across_fake_quant + if flags.change_concat_input_ranges: + converter.change_concat_input_ranges = flags.change_concat_input_ranges + if flags.allow_custom_ops: + converter.allow_custom_ops = flags.allow_custom_ops + if flags.quantize_weights: + converter.quantize_weights = flags.quantize_weights + if flags.dump_graphviz_dir: + converter.dump_graphviz_dir = flags.dump_graphviz_dir + if flags.dump_graphviz_video: + converter.dump_graphviz_video = flags.dump_graphviz_video + + # Convert model. + output_data = converter.convert() + with open(flags.output_file, "wb") as f: + f.write(output_data) + + +def _check_flags(flags, unparsed): + """Checks the parsed and unparsed flags to ensure they are valid. + + Raises an error if previously supported unparsed flags are found. Raises an + error for parsed flags that don't meet the required conditions. + + Args: + flags: argparse.Namespace object containing TFLite flags. + unparsed: List of unparsed flags. + + Raises: + ValueError: Invalid flags. + """ + + # Check unparsed flags for common mistakes based on previous TOCO.
+ def _get_message_unparsed(flag, orig_flag, new_flag): + if flag.startswith(orig_flag): + return "\n Use {0} instead of {1}".format(new_flag, orig_flag) + return "" + + if unparsed: + output = "" + for flag in unparsed: + output += _get_message_unparsed(flag, "--input_file", "--graph_def_file") + output += _get_message_unparsed(flag, "--savedmodel_directory", + "--saved_model_dir") + output += _get_message_unparsed(flag, "--std_value", "--std_dev_values") + output += _get_message_unparsed(flag, "--batch_size", "--input_shapes") + output += _get_message_unparsed(flag, "--dump_graphviz", + "--dump_graphviz_dir") + if output: + raise ValueError(output) + + # Check that flags are valid. + if flags.graph_def_file and (not flags.input_arrays or + not flags.output_arrays): + raise ValueError("--input_arrays and --output_arrays are required with " + "--graph_def_file") + + if flags.input_shapes: + if not flags.input_arrays: + raise ValueError("--input_shapes must be used with --input_arrays") + if flags.input_shapes.count(":") != flags.input_arrays.count(","): + raise ValueError("--input_shapes and --input_arrays must have the same " + "number of items") + + if flags.std_dev_values or flags.mean_values: + if bool(flags.std_dev_values) != bool(flags.mean_values): + raise ValueError("--std_dev_values and --mean_values must be used " + "together") + if flags.std_dev_values.count(",") != flags.mean_values.count(","): + raise ValueError("--std_dev_values, --mean_values must have the same " + "number of items") + + if (flags.default_ranges_min is None) != (flags.default_ranges_max is None): + raise ValueError("--default_ranges_min and --default_ranges_max must be " + "used together") + + +def run_main(_): + """Main in tflite_convert.py.""" + parser = argparse.ArgumentParser( + description=("Command line tool to run TensorFlow Lite Optimizing " + "Converter (TOCO).")) + + # Output file flag.
+ parser.add_argument( + "--output_file", + type=str, + help="Full filepath of the output file.", + required=True) + + # Input file flags. + input_file_group = parser.add_mutually_exclusive_group(required=True) + input_file_group.add_argument( + "--graph_def_file", + type=str, + help="Full filepath of file containing TensorFlow GraphDef.") + input_file_group.add_argument( + "--saved_model_dir", + type=str, + help="Full filepath of directory containing the SavedModel.") + + # Model format flags. + parser.add_argument( + "--output_format", + type=str.upper, + choices=["TFLITE", "GRAPHVIZ_DOT"], + help="Output file format.") + parser.add_argument( + "--inference_type", + type=str.upper, + choices=["FLOAT", "QUANTIZED_UINT8", "STRING"], + help="Target data type of arrays in the output file.") + parser.add_argument( + "--inference_input_type", + type=str.upper, + choices=["FLOAT", "QUANTIZED_UINT8", "STRING"], + help=("Target data type of input arrays. Allows for a different type for " + "input arrays in the case of quantization.")) + + # Input and output arrays flags. + parser.add_argument( + "--input_arrays", + type=str, + help="Names of the input arrays, comma-separated.") + parser.add_argument( + "--input_shapes", + type=str, + help="Shapes corresponding to --input_arrays, colon-separated.") + parser.add_argument( + "--output_arrays", + type=str, + help="Names of the output arrays, comma-separated.") + + # SavedModel related flags. + parser.add_argument( + "--saved_model_tag_set", + type=str, + help=("Comma-separated set of tags identifying the MetaGraphDef within " + "the SavedModel to analyze. All tags must be present. " + "(default \"serve\")")) + parser.add_argument( + "--saved_model_signature_key", + type=str, + help=("Key identifying the SignatureDef containing inputs and outputs. " + "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)")) + + # Quantization flags.
+ parser.add_argument( + "--std_dev_values", + type=str, + help=("Standard deviation of training data for each input tensor, " + "comma-separated. Used for quantization. (default None)")) + parser.add_argument( + "--mean_values", + type=str, + help=("Mean of training data for each input tensor, comma-separated. " + "Used for quantization. (default None)")) + parser.add_argument( + "--default_ranges_min", + type=int, + help=("Default value for min bound of min/max range values used for all " + "arrays without a specified range. Intended for experimenting with " + "quantization via \"dummy quantization\". (default None)")) + parser.add_argument( + "--default_ranges_max", + type=int, + help=("Default value for max bound of min/max range values used for all " + "arrays without a specified range. Intended for experimenting with " + "quantization via \"dummy quantization\". (default None)")) + parser.add_argument( + "--quantize_weights", + type=bool, + help=("Store float weights as quantized weights followed by dequantize " + "operations. Inference is still done in FLOAT, but the model size " + "is reduced (at the cost of accuracy and latency).")) + + # Graph manipulation flags. + parser.add_argument( + "--drop_control_dependency", + action="store_true", + help=("Boolean indicating whether to drop control dependencies silently. " + "This is due to TensorFlow Lite not supporting control " + "dependencies. (default True)")) + parser.add_argument( + "--reorder_across_fake_quant", + action="store_true", + help=("Boolean indicating whether to reorder FakeQuant nodes in " + "unexpected locations. Used when the location of the FakeQuant " + "nodes is preventing graph transformations necessary to convert " + "the graph. Results in a graph that differs from the quantized " + "training graph, potentially causing differing arithmetic " + "behavior.
(default False)")) + parser.add_argument( + "--change_concat_input_ranges", + action="store_true", + help=("Boolean to change behavior of min/max ranges for inputs and " + "outputs of the concat operator for quantized models. Changes the " + "ranges of concat operator overlap when true. (default False)")) + parser.add_argument( + "--allow_custom_ops", + action="store_true", + help=("Boolean indicating whether to allow custom operations. When false " + "any unknown operation is an error. When true, custom ops are " + "created for any op that is unknown. The developer will need to " + "provide these to the TensorFlow Lite runtime with a custom " + "resolver. (default False)")) + + # Logging flags. + parser.add_argument( + "--dump_graphviz_dir", + type=str, + help=("Full filepath of folder to dump the graphs at various stages of " + "processing GraphViz .dot files. Preferred over --output_format=" + "GRAPHVIZ_DOT in order to keep the requirements of the output " + "file.")) + parser.add_argument( + "--dump_graphviz_video", + action="store_true", + help=("Boolean indicating whether to dump the graph after every graph " + "transformation")) + + tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:]) + try: + _check_flags(tflite_flags, unparsed) + except ValueError as e: + parser.print_usage() + file_name = os.path.basename(sys.argv[0]) + sys.stderr.write("{0}: error: {1}\n".format(file_name, str(e))) + sys.exit(1) + _convert_model(tflite_flags) + + +def main(): + app.run(main=run_main, argv=sys.argv[:1]) + + +if __name__ == "__main__": + main() diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc index ac408d2f94b98d505afe4c951d7cc2ff960606fb..9dc8daa227dd68ccde2efa4013ac4465a72e6bb0 100644 --- a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc @@ -39,7 +39,7 @@ limitations under the 
License. #define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ // DO NOT EDIT MANUALLY: This file is automatically generated by -// `schema_builtin_ops_header_generator.py`. +// `schema/builtin_ops_header/generator.cc`. #ifdef __cplusplus extern "C" { @@ -57,7 +57,6 @@ const char* kFileFooter = } // extern "C" #endif // __cplusplus #endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ -} )"; } // anonymous namespace diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 2f5c39e7d72197af858b92632065aed3d2caa642..ee5208df1456d01f1a99ecc69722f5fb4ab0763a 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -142,6 +142,15 @@ enum BuiltinOperator : byte { GREATER_EQUAL = 62, LESS_EQUAL = 63, SELECT = 64, + SLICE = 65, + SIN = 66, + TRANSPOSE_CONV = 67, + SPARSE_TO_DENSE = 68, + TILE = 69, + EXPAND_DIMS = 70, + EQUAL = 71, + NOT_EQUAL = 72, + LOG = 73, } // Options for the builtin operators. @@ -193,6 +202,13 @@ union BuiltinOptions { GreaterEqualOptions, LessEqualOptions, SelectOptions, + SliceOptions, + TransposeConvOptions, + SparseToDenseOptions, + TileOptions, + ExpandDimsOptions, + EqualOptions, + NotEqualOptions, } enum Padding : byte { SAME, VALID } @@ -304,11 +320,23 @@ table LocalResponseNormalizationOptions { beta:float; } +enum LSTMKernelType : byte { + // Full LSTM kernel which supports peephole and projection. + FULL = 0, + // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell. + BASIC = 1, +} + // An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell table LSTMOptions { + // Parameters for LSTM version 1 or above. fused_activation_function:ActivationFunctionType; cell_clip: float; // Optional, 0.0 means no clipping proj_clip: float; // Optional, 0.0 means no clipping + + // Parameters for LSTM version 2 or above. + // Basic kernel is only supported in version 2 or above. 
+ kernel_type: LSTMKernelType = FULL; } table ResizeBilinearOptions { @@ -414,6 +442,9 @@ table DequantizeOptions { table MaximumMinimumOptions { } +table TileOptions { +} + table ArgMaxOptions { output_type : TensorType; } @@ -436,11 +467,37 @@ table NegOptions { table SelectOptions { } +table SliceOptions { +} + +table TransposeConvOptions { + padding:Padding; + stride_w:int; + stride_h:int; +} + +table ExpandDimsOptions { +} + +table SparseToDenseOptions { + validate_indices:bool; +} + +table EqualOptions { +} + +table NotEqualOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { builtin_code:BuiltinOperator; custom_code:string; + + // The version of the operator. The version needs to be bumped whenever new + // parameters are introduced into an op. + version:int = 1; } enum CustomOptionsFormat : byte { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h old mode 100644 new mode 100755 index a2f0c8cdd28934217e1f641e8e7165cfae87f73a..887e47ed1ea309d025d4be8745ffb8da06e8ee6b --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -151,6 +151,9 @@ struct DequantizeOptionsT; struct MaximumMinimumOptions; struct MaximumMinimumOptionsT; +struct TileOptions; +struct TileOptionsT; + struct ArgMaxOptions; struct ArgMaxOptionsT; @@ -172,6 +175,24 @@ struct NegOptionsT; struct SelectOptions; struct SelectOptionsT; +struct SliceOptions; +struct SliceOptionsT; + +struct TransposeConvOptions; +struct TransposeConvOptionsT; + +struct ExpandDimsOptions; +struct ExpandDimsOptionsT; + +struct SparseToDenseOptions; +struct SparseToDenseOptionsT; + +struct EqualOptions; +struct EqualOptionsT; + +struct NotEqualOptions; +struct NotEqualOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -296,11 +317,20 @@ enum BuiltinOperator {
BuiltinOperator_GREATER_EQUAL = 62, BuiltinOperator_LESS_EQUAL = 63, BuiltinOperator_SELECT = 64, + BuiltinOperator_SLICE = 65, + BuiltinOperator_SIN = 66, + BuiltinOperator_TRANSPOSE_CONV = 67, + BuiltinOperator_SPARSE_TO_DENSE = 68, + BuiltinOperator_TILE = 69, + BuiltinOperator_EXPAND_DIMS = 70, + BuiltinOperator_EQUAL = 71, + BuiltinOperator_NOT_EQUAL = 72, + BuiltinOperator_LOG = 73, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_SELECT + BuiltinOperator_MAX = BuiltinOperator_LOG }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[64] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[73] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -365,7 +395,16 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[64] { BuiltinOperator_GREATER, BuiltinOperator_GREATER_EQUAL, BuiltinOperator_LESS_EQUAL, - BuiltinOperator_SELECT + BuiltinOperator_SELECT, + BuiltinOperator_SLICE, + BuiltinOperator_SIN, + BuiltinOperator_TRANSPOSE_CONV, + BuiltinOperator_SPARSE_TO_DENSE, + BuiltinOperator_TILE, + BuiltinOperator_EXPAND_DIMS, + BuiltinOperator_EQUAL, + BuiltinOperator_NOT_EQUAL, + BuiltinOperator_LOG }; return values; } @@ -437,6 +476,15 @@ inline const char **EnumNamesBuiltinOperator() { "GREATER_EQUAL", "LESS_EQUAL", "SELECT", + "SLICE", + "SIN", + "TRANSPOSE_CONV", + "SPARSE_TO_DENSE", + "TILE", + "EXPAND_DIMS", + "EQUAL", + "NOT_EQUAL", + "LOG", nullptr }; return names; @@ -496,11 +544,18 @@ enum BuiltinOptions { BuiltinOptions_GreaterEqualOptions = 45, BuiltinOptions_LessEqualOptions = 46, BuiltinOptions_SelectOptions = 47, + BuiltinOptions_SliceOptions = 48, + BuiltinOptions_TransposeConvOptions = 49, + BuiltinOptions_SparseToDenseOptions = 50, + BuiltinOptions_TileOptions = 51, + BuiltinOptions_ExpandDimsOptions = 52, + BuiltinOptions_EqualOptions = 53, + BuiltinOptions_NotEqualOptions = 54, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = 
BuiltinOptions_SelectOptions + BuiltinOptions_MAX = BuiltinOptions_NotEqualOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[48] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[55] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -549,7 +604,14 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[48] { BuiltinOptions_GreaterOptions, BuiltinOptions_GreaterEqualOptions, BuiltinOptions_LessEqualOptions, - BuiltinOptions_SelectOptions + BuiltinOptions_SelectOptions, + BuiltinOptions_SliceOptions, + BuiltinOptions_TransposeConvOptions, + BuiltinOptions_SparseToDenseOptions, + BuiltinOptions_TileOptions, + BuiltinOptions_ExpandDimsOptions, + BuiltinOptions_EqualOptions, + BuiltinOptions_NotEqualOptions }; return values; } @@ -604,6 +666,13 @@ inline const char **EnumNamesBuiltinOptions() { "GreaterEqualOptions", "LessEqualOptions", "SelectOptions", + "SliceOptions", + "TransposeConvOptions", + "SparseToDenseOptions", + "TileOptions", + "ExpandDimsOptions", + "EqualOptions", + "NotEqualOptions", nullptr }; return names; @@ -806,6 +875,34 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SelectOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SliceOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_TransposeConvOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SparseToDenseOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_TileOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ExpandDimsOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_EqualOptions; +}; + +template<> struct 
BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_NotEqualOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -1213,6 +1310,62 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_SelectOptions ? reinterpret_cast(value) : nullptr; } + SliceOptionsT *AsSliceOptions() { + return type == BuiltinOptions_SliceOptions ? + reinterpret_cast(value) : nullptr; + } + const SliceOptionsT *AsSliceOptions() const { + return type == BuiltinOptions_SliceOptions ? + reinterpret_cast(value) : nullptr; + } + TransposeConvOptionsT *AsTransposeConvOptions() { + return type == BuiltinOptions_TransposeConvOptions ? + reinterpret_cast(value) : nullptr; + } + const TransposeConvOptionsT *AsTransposeConvOptions() const { + return type == BuiltinOptions_TransposeConvOptions ? + reinterpret_cast(value) : nullptr; + } + SparseToDenseOptionsT *AsSparseToDenseOptions() { + return type == BuiltinOptions_SparseToDenseOptions ? + reinterpret_cast(value) : nullptr; + } + const SparseToDenseOptionsT *AsSparseToDenseOptions() const { + return type == BuiltinOptions_SparseToDenseOptions ? + reinterpret_cast(value) : nullptr; + } + TileOptionsT *AsTileOptions() { + return type == BuiltinOptions_TileOptions ? + reinterpret_cast(value) : nullptr; + } + const TileOptionsT *AsTileOptions() const { + return type == BuiltinOptions_TileOptions ? + reinterpret_cast(value) : nullptr; + } + ExpandDimsOptionsT *AsExpandDimsOptions() { + return type == BuiltinOptions_ExpandDimsOptions ? + reinterpret_cast(value) : nullptr; + } + const ExpandDimsOptionsT *AsExpandDimsOptions() const { + return type == BuiltinOptions_ExpandDimsOptions ? + reinterpret_cast(value) : nullptr; + } + EqualOptionsT *AsEqualOptions() { + return type == BuiltinOptions_EqualOptions ? + reinterpret_cast(value) : nullptr; + } + const EqualOptionsT *AsEqualOptions() const { + return type == BuiltinOptions_EqualOptions ? 
+ reinterpret_cast(value) : nullptr; + } + NotEqualOptionsT *AsNotEqualOptions() { + return type == BuiltinOptions_NotEqualOptions ? + reinterpret_cast(value) : nullptr; + } + const NotEqualOptionsT *AsNotEqualOptions() const { + return type == BuiltinOptions_NotEqualOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -1320,6 +1473,35 @@ inline const char *EnumNameLSHProjectionType(LSHProjectionType e) { return EnumNamesLSHProjectionType()[index]; } +enum LSTMKernelType { + LSTMKernelType_FULL = 0, + LSTMKernelType_BASIC = 1, + LSTMKernelType_MIN = LSTMKernelType_FULL, + LSTMKernelType_MAX = LSTMKernelType_BASIC +}; + +inline LSTMKernelType (&EnumValuesLSTMKernelType())[2] { + static LSTMKernelType values[] = { + LSTMKernelType_FULL, + LSTMKernelType_BASIC + }; + return values; +} + +inline const char **EnumNamesLSTMKernelType() { + static const char *names[] = { + "FULL", + "BASIC", + nullptr + }; + return names; +} + +inline const char *EnumNameLSTMKernelType(LSTMKernelType e) { + const size_t index = static_cast(e); + return EnumNamesLSTMKernelType()[index]; +} + enum CombinerType { CombinerType_SUM = 0, CombinerType_MEAN = 1, @@ -2757,10 +2939,12 @@ struct LSTMOptionsT : public flatbuffers::NativeTable { ActivationFunctionType fused_activation_function; float cell_clip; float proj_clip; + LSTMKernelType kernel_type; LSTMOptionsT() : fused_activation_function(ActivationFunctionType_NONE), cell_clip(0.0f), - proj_clip(0.0f) { + proj_clip(0.0f), + kernel_type(LSTMKernelType_FULL) { } }; @@ -2769,7 +2953,8 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { enum { VT_FUSED_ACTIVATION_FUNCTION = 4, VT_CELL_CLIP = 6, - VT_PROJ_CLIP = 8 + VT_PROJ_CLIP = 8, + VT_KERNEL_TYPE = 10 }; ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); @@ -2780,11 +2965,15 @@ struct LSTMOptions 
FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { float proj_clip() const { return GetField(VT_PROJ_CLIP, 0.0f); } + LSTMKernelType kernel_type() const { + return static_cast(GetField(VT_KERNEL_TYPE, 0)); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField(verifier, VT_CELL_CLIP) && VerifyField(verifier, VT_PROJ_CLIP) && + VerifyField(verifier, VT_KERNEL_TYPE) && verifier.EndTable(); } LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -2804,6 +2993,9 @@ struct LSTMOptionsBuilder { void add_proj_clip(float proj_clip) { fbb_.AddElement(LSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f); } + void add_kernel_type(LSTMKernelType kernel_type) { + fbb_.AddElement(LSTMOptions::VT_KERNEL_TYPE, static_cast(kernel_type), 0); + } explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2820,10 +3012,12 @@ inline flatbuffers::Offset CreateLSTMOptions( flatbuffers::FlatBufferBuilder &_fbb, ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE, float cell_clip = 0.0f, - float proj_clip = 0.0f) { + float proj_clip = 0.0f, + LSTMKernelType kernel_type = LSTMKernelType_FULL) { LSTMOptionsBuilder builder_(_fbb); builder_.add_proj_clip(proj_clip); builder_.add_cell_clip(cell_clip); + builder_.add_kernel_type(kernel_type); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } @@ -4086,6 +4280,46 @@ inline flatbuffers::Offset CreateMaximumMinimumOptions( flatbuffers::Offset CreateMaximumMinimumOptions(flatbuffers::FlatBufferBuilder &_fbb, const MaximumMinimumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct TileOptionsT : public flatbuffers::NativeTable { + typedef TileOptions TableType; + TileOptionsT() { + } +}; + +struct TileOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table 
{ + typedef TileOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + TileOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TileOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct TileOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit TileOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TileOptionsBuilder &operator=(const TileOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTileOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + TileOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateTileOptions(flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct ArgMaxOptionsT : public flatbuffers::NativeTable { typedef ArgMaxOptions TableType; TensorType output_type; @@ -4380,12 +4614,306 @@ inline flatbuffers::Offset CreateSelectOptions( flatbuffers::Offset CreateSelectOptions(flatbuffers::FlatBufferBuilder &_fbb, const SelectOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct SliceOptionsT : public flatbuffers::NativeTable { + typedef SliceOptions TableType; + SliceOptionsT() { + } +}; + +struct SliceOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SliceOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + SliceOptionsT 
*UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SliceOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SliceOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SliceOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit SliceOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SliceOptionsBuilder &operator=(const SliceOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSliceOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + SliceOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateSliceOptions(flatbuffers::FlatBufferBuilder &_fbb, const SliceOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct TransposeConvOptionsT : public flatbuffers::NativeTable { + typedef TransposeConvOptions TableType; + Padding padding; + int32_t stride_w; + int32_t stride_h; + TransposeConvOptionsT() + : padding(Padding_SAME), + stride_w(0), + stride_h(0) { + } +}; + +struct TransposeConvOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TransposeConvOptionsT NativeTableType; + enum { + VT_PADDING = 4, + VT_STRIDE_W = 6, + VT_STRIDE_H = 8 + }; + Padding padding() const { + return static_cast(GetField(VT_PADDING, 0)); + } + int32_t stride_w() const { + return GetField(VT_STRIDE_W, 0); + } + int32_t stride_h() const { + return GetField(VT_STRIDE_H, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_PADDING) && + VerifyField(verifier, VT_STRIDE_W) && + VerifyField(verifier, VT_STRIDE_H) && + 
verifier.EndTable(); + } + TransposeConvOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TransposeConvOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const TransposeConvOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct TransposeConvOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_padding(Padding padding) { + fbb_.AddElement(TransposeConvOptions::VT_PADDING, static_cast(padding), 0); + } + void add_stride_w(int32_t stride_w) { + fbb_.AddElement(TransposeConvOptions::VT_STRIDE_W, stride_w, 0); + } + void add_stride_h(int32_t stride_h) { + fbb_.AddElement(TransposeConvOptions::VT_STRIDE_H, stride_h, 0); + } + explicit TransposeConvOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TransposeConvOptionsBuilder &operator=(const TransposeConvOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTransposeConvOptions( + flatbuffers::FlatBufferBuilder &_fbb, + Padding padding = Padding_SAME, + int32_t stride_w = 0, + int32_t stride_h = 0) { + TransposeConvOptionsBuilder builder_(_fbb); + builder_.add_stride_h(stride_h); + builder_.add_stride_w(stride_w); + builder_.add_padding(padding); + return builder_.Finish(); +} + +flatbuffers::Offset CreateTransposeConvOptions(flatbuffers::FlatBufferBuilder &_fbb, const TransposeConvOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ExpandDimsOptionsT : public flatbuffers::NativeTable { + typedef ExpandDimsOptions TableType; + ExpandDimsOptionsT() { + } +}; + +struct ExpandDimsOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ExpandDimsOptionsT 
NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + ExpandDimsOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ExpandDimsOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ExpandDimsOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit ExpandDimsOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ExpandDimsOptionsBuilder &operator=(const ExpandDimsOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateExpandDimsOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + ExpandDimsOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateExpandDimsOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SparseToDenseOptionsT : public flatbuffers::NativeTable { + typedef SparseToDenseOptions TableType; + bool validate_indices; + SparseToDenseOptionsT() + : validate_indices(false) { + } +}; + +struct SparseToDenseOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SparseToDenseOptionsT NativeTableType; + enum { + VT_VALIDATE_INDICES = 4 + }; + bool validate_indices() const { + return GetField(VT_VALIDATE_INDICES, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_VALIDATE_INDICES) && + verifier.EndTable(); + } + SparseToDenseOptionsT *UnPack(const flatbuffers::resolver_function_t 
*_resolver = nullptr) const; + void UnPackTo(SparseToDenseOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SparseToDenseOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_validate_indices(bool validate_indices) { + fbb_.AddElement(SparseToDenseOptions::VT_VALIDATE_INDICES, static_cast(validate_indices), 0); + } + explicit SparseToDenseOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SparseToDenseOptionsBuilder &operator=(const SparseToDenseOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSparseToDenseOptions( + flatbuffers::FlatBufferBuilder &_fbb, + bool validate_indices = false) { + SparseToDenseOptionsBuilder builder_(_fbb); + builder_.add_validate_indices(validate_indices); + return builder_.Finish(); +} + +flatbuffers::Offset CreateSparseToDenseOptions(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct EqualOptionsT : public flatbuffers::NativeTable { + typedef EqualOptions TableType; + EqualOptionsT() { + } +}; + +struct EqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef EqualOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + EqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(EqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const 
EqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct EqualOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit EqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + EqualOptionsBuilder &operator=(const EqualOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateEqualOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + EqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct NotEqualOptionsT : public flatbuffers::NativeTable { + typedef NotEqualOptions TableType; + NotEqualOptionsT() { + } +}; + +struct NotEqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef NotEqualOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + NotEqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(NotEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct NotEqualOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit NotEqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + NotEqualOptionsBuilder &operator=(const NotEqualOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return 
o; + } +}; + +inline flatbuffers::Offset CreateNotEqualOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + NotEqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateNotEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; std::string custom_code; + int32_t version; OperatorCodeT() - : builtin_code(BuiltinOperator_ADD) { + : builtin_code(BuiltinOperator_ADD), + version(1) { } }; @@ -4393,7 +4921,8 @@ struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef OperatorCodeT NativeTableType; enum { VT_BUILTIN_CODE = 4, - VT_CUSTOM_CODE = 6 + VT_CUSTOM_CODE = 6, + VT_VERSION = 8 }; BuiltinOperator builtin_code() const { return static_cast(GetField(VT_BUILTIN_CODE, 0)); @@ -4401,11 +4930,15 @@ struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::String *custom_code() const { return GetPointer(VT_CUSTOM_CODE); } + int32_t version() const { + return GetField(VT_VERSION, 1); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_BUILTIN_CODE) && VerifyOffset(verifier, VT_CUSTOM_CODE) && verifier.Verify(custom_code()) && + VerifyField(verifier, VT_VERSION) && verifier.EndTable(); } OperatorCodeT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -4422,6 +4955,9 @@ struct OperatorCodeBuilder { void add_custom_code(flatbuffers::Offset custom_code) { fbb_.AddOffset(OperatorCode::VT_CUSTOM_CODE, custom_code); } + void add_version(int32_t version) { + fbb_.AddElement(OperatorCode::VT_VERSION, version, 1); + } explicit OperatorCodeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -4437,8 +4973,10 @@ struct OperatorCodeBuilder { inline 
flatbuffers::Offset CreateOperatorCode( flatbuffers::FlatBufferBuilder &_fbb, BuiltinOperator builtin_code = BuiltinOperator_ADD, - flatbuffers::Offset custom_code = 0) { + flatbuffers::Offset custom_code = 0, + int32_t version = 1) { OperatorCodeBuilder builder_(_fbb); + builder_.add_version(version); builder_.add_custom_code(custom_code); builder_.add_builtin_code(builtin_code); return builder_.Finish(); @@ -4447,11 +4985,13 @@ inline flatbuffers::Offset CreateOperatorCode( inline flatbuffers::Offset CreateOperatorCodeDirect( flatbuffers::FlatBufferBuilder &_fbb, BuiltinOperator builtin_code = BuiltinOperator_ADD, - const char *custom_code = nullptr) { + const char *custom_code = nullptr, + int32_t version = 1) { return tflite::CreateOperatorCode( _fbb, builtin_code, - custom_code ? _fbb.CreateString(custom_code) : 0); + custom_code ? _fbb.CreateString(custom_code) : 0, + version); } flatbuffers::Offset CreateOperatorCode(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -4638,6 +5178,27 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const SelectOptions *builtin_options_as_SelectOptions() const { return builtin_options_type() == BuiltinOptions_SelectOptions ? static_cast(builtin_options()) : nullptr; } + const SliceOptions *builtin_options_as_SliceOptions() const { + return builtin_options_type() == BuiltinOptions_SliceOptions ? static_cast(builtin_options()) : nullptr; + } + const TransposeConvOptions *builtin_options_as_TransposeConvOptions() const { + return builtin_options_type() == BuiltinOptions_TransposeConvOptions ? static_cast(builtin_options()) : nullptr; + } + const SparseToDenseOptions *builtin_options_as_SparseToDenseOptions() const { + return builtin_options_type() == BuiltinOptions_SparseToDenseOptions ? 
static_cast(builtin_options()) : nullptr; + } + const TileOptions *builtin_options_as_TileOptions() const { + return builtin_options_type() == BuiltinOptions_TileOptions ? static_cast(builtin_options()) : nullptr; + } + const ExpandDimsOptions *builtin_options_as_ExpandDimsOptions() const { + return builtin_options_type() == BuiltinOptions_ExpandDimsOptions ? static_cast(builtin_options()) : nullptr; + } + const EqualOptions *builtin_options_as_EqualOptions() const { + return builtin_options_type() == BuiltinOptions_EqualOptions ? static_cast(builtin_options()) : nullptr; + } + const NotEqualOptions *builtin_options_as_NotEqualOptions() const { + return builtin_options_type() == BuiltinOptions_NotEqualOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -4852,6 +5413,34 @@ template<> inline const SelectOptions *Operator::builtin_options_as inline const SliceOptions *Operator::builtin_options_as() const { + return builtin_options_as_SliceOptions(); +} + +template<> inline const TransposeConvOptions *Operator::builtin_options_as() const { + return builtin_options_as_TransposeConvOptions(); +} + +template<> inline const SparseToDenseOptions *Operator::builtin_options_as() const { + return builtin_options_as_SparseToDenseOptions(); +} + +template<> inline const TileOptions *Operator::builtin_options_as() const { + return builtin_options_as_TileOptions(); +} + +template<> inline const ExpandDimsOptions *Operator::builtin_options_as() const { + return builtin_options_as_ExpandDimsOptions(); +} + +template<> inline const EqualOptions *Operator::builtin_options_as() const { + return builtin_options_as_EqualOptions(); +} + +template<> inline const NotEqualOptions *Operator::builtin_options_as() const { + return builtin_options_as_NotEqualOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -5817,6 +6406,7 @@ inline 
void LSTMOptions::UnPackTo(LSTMOptionsT *_o, const flatbuffers::resolver_ { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; { auto _e = cell_clip(); _o->cell_clip = _e; }; { auto _e = proj_clip(); _o->proj_clip = _e; }; + { auto _e = kernel_type(); _o->kernel_type = _e; }; } inline flatbuffers::Offset LSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -5830,11 +6420,13 @@ inline flatbuffers::Offset CreateLSTMOptions(flatbuffers::FlatBuffe auto _fused_activation_function = _o->fused_activation_function; auto _cell_clip = _o->cell_clip; auto _proj_clip = _o->proj_clip; + auto _kernel_type = _o->kernel_type; return tflite::CreateLSTMOptions( _fbb, _fused_activation_function, _cell_clip, - _proj_clip); + _proj_clip, + _kernel_type); } inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -6452,6 +7044,29 @@ inline flatbuffers::Offset CreateMaximumMinimumOptions(fl _fbb); } +inline TileOptionsT *TileOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new TileOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void TileOptions::UnPackTo(TileOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset TileOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateTileOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateTileOptions(flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TileOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return 
tflite::CreateTileOptions( + _fbb); +} + inline ArgMaxOptionsT *ArgMaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new ArgMaxOptionsT(); UnPackTo(_o, _resolver); @@ -6616,6 +7231,156 @@ inline flatbuffers::Offset CreateSelectOptions(flatbuffers::FlatB _fbb); } +inline SliceOptionsT *SliceOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SliceOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SliceOptions::UnPackTo(SliceOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset SliceOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SliceOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSliceOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateSliceOptions(flatbuffers::FlatBufferBuilder &_fbb, const SliceOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SliceOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateSliceOptions( + _fbb); +} + +inline TransposeConvOptionsT *TransposeConvOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new TransposeConvOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void TransposeConvOptions::UnPackTo(TransposeConvOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = padding(); _o->padding = _e; }; + { auto _e = stride_w(); _o->stride_w = _e; }; + { auto _e = stride_h(); _o->stride_h = _e; }; +} + +inline flatbuffers::Offset TransposeConvOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TransposeConvOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return 
CreateTransposeConvOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateTransposeConvOptions(flatbuffers::FlatBufferBuilder &_fbb, const TransposeConvOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TransposeConvOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _padding = _o->padding; + auto _stride_w = _o->stride_w; + auto _stride_h = _o->stride_h; + return tflite::CreateTransposeConvOptions( + _fbb, + _padding, + _stride_w, + _stride_h); +} + +inline ExpandDimsOptionsT *ExpandDimsOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ExpandDimsOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ExpandDimsOptions::UnPackTo(ExpandDimsOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset ExpandDimsOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateExpandDimsOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateExpandDimsOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ExpandDimsOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateExpandDimsOptions( + _fbb); +} + +inline SparseToDenseOptionsT *SparseToDenseOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SparseToDenseOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SparseToDenseOptions::UnPackTo(SparseToDenseOptionsT *_o, const 
flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = validate_indices(); _o->validate_indices = _e; }; +} + +inline flatbuffers::Offset SparseToDenseOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSparseToDenseOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateSparseToDenseOptions(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SparseToDenseOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _validate_indices = _o->validate_indices; + return tflite::CreateSparseToDenseOptions( + _fbb, + _validate_indices); +} + +inline EqualOptionsT *EqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new EqualOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void EqualOptions::UnPackTo(EqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset EqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateEqualOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const EqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateEqualOptions( + _fbb); +} + +inline NotEqualOptionsT *NotEqualOptions::UnPack(const flatbuffers::resolver_function_t 
*_resolver) const { + auto _o = new NotEqualOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void NotEqualOptions::UnPackTo(NotEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset NotEqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateNotEqualOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateNotEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const NotEqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateNotEqualOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -6627,6 +7392,7 @@ inline void OperatorCode::UnPackTo(OperatorCodeT *_o, const flatbuffers::resolve (void)_resolver; { auto _e = builtin_code(); _o->builtin_code = _e; }; { auto _e = custom_code(); if (_e) _o->custom_code = _e->str(); }; + { auto _e = version(); _o->version = _e; }; } inline flatbuffers::Offset OperatorCode::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -6639,10 +7405,12 @@ inline flatbuffers::Offset CreateOperatorCode(flatbuffers::FlatBuf struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OperatorCodeT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _builtin_code = _o->builtin_code; auto _custom_code = _o->custom_code.empty() ? 
0 : _fbb.CreateString(_o->custom_code); + auto _version = _o->version; return tflite::CreateOperatorCode( _fbb, _builtin_code, - _custom_code); + _custom_code, + _version); } inline OperatorT *Operator::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -6987,6 +7755,34 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_SliceOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_TransposeConvOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_TileOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ExpandDimsOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -7193,6 +7989,34 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_SliceOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_TransposeConvOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_TileOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ExpandDimsOptions: { + auto ptr = reinterpret_cast(obj); + 
return ptr->UnPack(resolver); + } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -7387,6 +8211,34 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateSelectOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_SliceOptions: { + auto ptr = reinterpret_cast(value); + return CreateSliceOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_TransposeConvOptions: { + auto ptr = reinterpret_cast(value); + return CreateTransposeConvOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(value); + return CreateSparseToDenseOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_TileOptions: { + auto ptr = reinterpret_cast(value); + return CreateTileOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ExpandDimsOptions: { + auto ptr = reinterpret_cast(value); + return CreateExpandDimsOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast(value); + return CreateEqualOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast(value); + return CreateNotEqualOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -7581,6 +8433,34 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new SelectOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_SliceOptions: { + value = new SliceOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_TransposeConvOptions: { + value = new TransposeConvOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SparseToDenseOptions: { + value = new 
SparseToDenseOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_TileOptions: { + value = new TileOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ExpandDimsOptions: { + value = new ExpandDimsOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_EqualOptions: { + value = new EqualOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_NotEqualOptions: { + value = new NotEqualOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -7823,6 +8703,41 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_SliceOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_TransposeConvOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_TileOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ExpandDimsOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index f89c0d28d37b663ce8822297ea1ee0132eece5c9..80e4c5a4dde4702229887593afc5ffeef339176d 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -6,7 +6,8 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow/contrib/lite:build_def.bzl", - "gen_zipped_test_files", + "gen_zip_test", + "generated_test_models", ) load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") load( @@ -14,57 +15,52 @@ load( "tf_cc_test", ) -gen_zipped_test_files( - 
name = "optest", - files = [ - "add.zip", - "arg_max.zip", - "avg_pool.zip", - "batch_to_space_nd.zip", - "concat.zip", - "constant.zip", - "control_dep.zip", - "conv.zip", - "depthwiseconv.zip", - "div.zip", - "exp.zip", - "floor.zip", - "fully_connected.zip", - "fused_batch_norm.zip", - "gather.zip", - "global_batch_norm.zip", - "greater.zip", - "greater_equal.zip", - "l2_pool.zip", - "l2norm.zip", - "less.zip", - "less_equal.zip", - "local_response_norm.zip", - "log_softmax.zip", - "max_pool.zip", - "maximum.zip", - "mean.zip", - "minimum.zip", - "mul.zip", - "neg.zip", - "pad.zip", - "padv2.zip", - "relu.zip", - "relu1.zip", - "relu6.zip", - "reshape.zip", - "resize_bilinear.zip", - "sigmoid.zip", - "softmax.zip", - "space_to_batch_nd.zip", - "space_to_depth.zip", - "split.zip", - "squeeze.zip", - "strided_slice.zip", - "sub.zip", - "topk.zip", - "transpose.zip", - "where.zip", +[gen_zip_test( + name = "zip_test_%s" % test_name, + size = "large", + srcs = ["generated_examples_zip_test.cc"], + args = [ + "--zip_file_path=$(location :zip_%s)" % test_name, + # TODO(angerson) We may be able to add an external unzip binary instead + # of relying on an existing one for OSS builds. 
+ "--unzip_binary_path=/usr/bin/unzip", + ], + data = [ + ":zip_%s" % test_name, + ], + shard_count = 20, + tags = [ + "gen_zip_test", + "no_oss", + "tflite_not_portable", + ], + test_name = test_name, + deps = [ + ":parse_testdata_lib", + ":tflite_driver", + ":util", + "@com_google_googletest//:gtest", + "@com_googlesource_code_re2//:re2", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ] + select({ + "//conditions:default": [ + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:android_tensorflow_test_lib", + ], + }), +) for test_name in generated_test_models()] + +test_suite( + name = "generated_zip_tests", + tags = [ + "gen_zip_test", ], ) @@ -159,6 +155,7 @@ cc_library( deps = [ ":split", ":test_runner", + "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite/kernels:builtin_ops", ], @@ -290,13 +287,9 @@ cc_library( deps = [ ":generate_testspec", ":parse_testdata_lib", - ":split", ":tflite_driver", - ":util", - "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string", - "//tensorflow/contrib/lite/kernels:builtin_ops", ], ) @@ -353,42 +346,4 @@ cc_binary( ], ) -tf_cc_test( - name = "generated_examples_zip_test", - size = "large", - srcs = ["generated_examples_zip_test.cc"], - args = [ - "--zip_files_dir=tensorflow/contrib/lite/testing/optest", - # TODO(angerson) We may be able to add an external unzip binary instead - # of relying on an existing one for OSS builds. 
- "--unzip_binary_path=/usr/bin/unzip", - ], - data = [":optest"], - shard_count = 20, - tags = [ - "no_oss", - "tflite_not_portable", - ], - deps = [ - ":parse_testdata_lib", - ":tflite_driver", - ":util", - "@com_google_googletest//:gtest", - "@com_googlesource_code_re2//:re2", - "//tensorflow/contrib/lite:builtin_op_data", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:builtin_ops", - ] + select({ - "//conditions:default": [ - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], - "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib", - "//tensorflow/core:android_tensorflow_test_lib", - ], - }), -) - tflite_portable_test_suite() diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index f7cc7da900089ad0cde3fcb48e118b4929c439fe..f5e25784fa17209af7cfb06d32aeea2b9b947196 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -20,14 +20,21 @@ Usage: generate_examples bazel run //tensorflow/contrib/lite/testing:generate_examples + +To more easily debug failures use (or override) the --save_graphdefs flag to +place text proto graphdefs into the generated zip files. 
""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse +import functools import itertools +import operator import os +import random import re import sys import tempfile @@ -51,10 +58,11 @@ from tensorflow.python.ops import rnn parser = argparse.ArgumentParser(description="Script to generate TFLite tests.") parser.add_argument("output_path", help="Directory where the outputs will be go.") -parser.add_argument("--zip_to_output", - type=str, - help="Particular zip to output.", - required=False) +parser.add_argument( + "--zip_to_output", + type=str, + help="Particular zip to output.", + required=True) parser.add_argument("--toco", type=str, help="Path to toco tool.", @@ -90,21 +98,12 @@ KNOWN_BUGS = { r"fully_connected.*transpose_.=True": "67586970", # Softmax graphs are too complex. r"softmax.*dim=0": "67749831", - r"softmax.*input_shape=\[1,3,4,3\]": "67749831", - # SpaceToDepth only supports float32. - r"space_to_depth.*(float16|int32|uint8|int64)": "68018134", # BatchToSpaceND only supports 4D tensors. r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733", # Div will use floordiv. r"div.*int32": "72051395", - # TOCO require matching dimensions in strided_slice. - r"strided_slice.*begin=\[0\].*end=\[1\].*": "73170889", # No support for SplitV r"split.*num_or_size_splits=\[2,2\]": "73377559", - # Needs support for dimensions other than the last one in argmax. - r"arg_max.*axis=0.*": "77546240", - r"arg_max.*axis=1.*": "77546240", - r"arg_max.*axis=2.*": "77546240", } @@ -118,6 +117,8 @@ class ExtraTocoOptions(object): self.allow_custom_ops = False # Rnn states that are used to support rnn / lstm cells. self.rnn_states = None + # Split the LSTM inputs from 5 inoputs to 18 inputs for TFLite. 
+ self.split_tflite_lstm_inputs = None def toco_options(data_types, @@ -146,14 +147,20 @@ def toco_options(data_types, " --inference_type=%s" % inference_type + " --input_format=TENSORFLOW_GRAPHDEF" + " --output_format=TFLITE" + " --input_arrays=%s" % ",".join(input_arrays) + - " --input_shapes=%s" % shape_str + " --output_arrays=%s" % ",".join(output_arrays)) + if shape_str: + s += (" --input_shapes=%s" % shape_str) if extra_toco_options.drop_control_dependency: s += " --drop_control_dependency" if extra_toco_options.allow_custom_ops: s += " --allow_custom_ops" if extra_toco_options.rnn_states: s += (" --rnn_states='" + extra_toco_options.rnn_states + "'") + if extra_toco_options.split_tflite_lstm_inputs is not None: + if extra_toco_options.split_tflite_lstm_inputs: + s += " --split_tflite_lstm_inputs=true" + else: + s += " --split_tflite_lstm_inputs=false" return s @@ -238,6 +245,19 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100): return value.astype(dtype) +def create_scalar_data(dtype, min_value=-100, max_value=100): + """Build scalar tensor data range from min_value to max_value exclusively.""" + + if dtype in _TF_TYPE_INFO: + dtype = _TF_TYPE_INFO[dtype][0] + + if dtype in (tf.float32, tf.float16): + value = (max_value - min_value) * np.random.random() + min_value + elif dtype in (tf.int32, tf.uint8, tf.int64): + value = np.random.randint(min_value, max_value + 1) + return np.array(value, dtype=dtype) + + def freeze_graph(session, outputs): """Freeze the current graph. @@ -328,6 +348,11 @@ def normalize_output_name(output_name): ":0") else output_name +# How many test cases we may have in a zip file. Too many test cases will +# slow down the test data generation process. +_MAX_TESTS_PER_ZIP = 500 + + def make_zip_of_tests(zip_path, test_parameters, make_graph, @@ -357,19 +382,39 @@ def make_zip_of_tests(zip_path, Raises: RuntimeError: if there are toco errors that can't be ignored. 
""" + parameter_count = 0 + for parameters in test_parameters: + parameter_count += functools.reduce( + operator.mul, [len(values) for values in parameters.values()]) + + if parameter_count > _MAX_TESTS_PER_ZIP: + raise RuntimeError( + "Too many parameter combinations for generating '%s'.\n" + "There are %d combinations while the upper limit is %d.\n" + "Having too many combinations will slow down the tests.\n" + "Please consider splitting the test into multiple functions.\n" + % (zip_path, parameter_count, _MAX_TESTS_PER_ZIP)) # TODO(aselle): Make this allow multiple inputs outputs. archive = zipfile.PyZipFile(zip_path, "w") zip_manifest = [] convert_report = [] toco_errors = 0 + + processed_labels = set() for parameters in test_parameters: keys = parameters.keys() for curr in itertools.product(*parameters.values()): - label = zip_path.replace(".zip", "") + (",".join( + label = zip_path.replace(".zip", "_") + (",".join( "%s=%r" % z for z in sorted(zip(keys, curr))).replace(" ", "")) if label[0] == "/": label = label[1:] + if label in processed_labels: + # Do not populate data for the same label more than once. It will cause + # errors when unzipping. 
+ continue + processed_labels.add(label) + param_dict = dict(zip(keys, curr)) def build_example(label, param_dict_real): @@ -422,6 +467,11 @@ def make_zip_of_tests(zip_path, sess, tf.global_variables() + inputs + outputs) if use_frozen_graph else sess.graph_def + + if "split_tflite_lstm_inputs" in param_dict_real: + extra_toco_options.split_tflite_lstm_inputs = param_dict_real[ + "split_tflite_lstm_inputs"] + tflite_model_binary, toco_log = toco_convert( graph_def.SerializeToString(), input_tensors, output_tensors, extra_toco_options) @@ -430,7 +480,7 @@ def make_zip_of_tests(zip_path, report["toco_log"] = toco_log if FLAGS.save_graphdefs: - archive.writestr(label + ".pb", + archive.writestr(label + ".pbtxt", text_format.MessageToString(graph_def), zipfile.ZIP_DEFLATED) @@ -468,6 +518,7 @@ def make_zip_of_tests(zip_path, report["toco_log"]) convert_report.append((param_dict, report)) + report_io = StringIO() report_lib.make_report_table(report_io, zip_path, convert_report) archive.writestr("report.html", report_io.getvalue()) @@ -704,65 +755,83 @@ def make_binary_op_tests(zip_path, binary_operator): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) -def make_mean_tests(zip_path): - """Make a set of tests to do mean.""" +def make_reduce_tests(reduce_op): + """Make a set of tests to do reduce operation. 
- test_parameters = [{ - "input_dtype": [tf.float32, tf.int32, tf.int64], - "input_shape": [[3, 2, 4]], - "axis": [ - None, 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0], - [2, 1], [2, 1, 0], [2, 0, 1], -1, -2, -3, [1, -1], [0, -1], [-1, 0], - [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3] - ], - "const_axis": [True, False], - "keepdims": [True, False], - }, { - "input_dtype": [tf.float32, tf.int32, tf.int64], - "input_shape": [[1, 224, 224, 3]], - "axis": [ - None, 0, 1, 2, 3, [1, 2], [0, 3], [1, 2, 3], [0, 1, 2, 3], - [3, 2, 1, 0], [3, 1, 0, 2], [2, 0], [3, 0], [3, 1], [1, 0], -1, -2, - -3, -4, [0, -2], [2, 3, -1, 0], [3, 1, 2, -3], [3, -4], [2, 2, 2], - [2, 2, 3], [-3, -3, -4], [-3, 2, 1] - ], - "const_axis": [True, False], - "keepdims": [True, False], - }] + Args: + reduce_op: TensorFlow reduce operation to test, e.g. `tf.reduce_mean`. - def build_graph(parameters): - """Build the mean op testing graph.""" - input_tensor = tf.placeholder( - dtype=parameters["input_dtype"], - name="input", - shape=parameters["input_shape"]) + Returns: + a function representing the true generator with `reduce_op` curried. + """ - # Get axis as either a placeholder or constants.
- if parameters["const_axis"]: - axis = parameters["axis"] - input_tensors = [input_tensor] - else: - if isinstance(parameters["axis"], list): - shape = [len(parameters["axis"])] + def f(zip_path): + """Actual function that generates examples.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape": [[3, 2, 4]], + "axis": [ + None, 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0], + [2, 1], [2, 1, 0], [2, 0, 1], -1, -2, -3, [1, -1], [0, -1], [-1, 0], + [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3] + ], + "const_axis": [True, False], + "keepdims": [True, False], + }, { + "input_dtype": [tf.float32], + "input_shape": [[1, 8, 8, 3]], + "axis": [ + None, 0, 1, 2, 3, [1, 2], [0, 3], [1, 2, 3], [0, 1, 2, 3], + [3, 2, 1, 0], [3, 1, 0, 2], [2, 0], [3, 0], [3, 1], [1, 0], -1, -2, + -3, -4, [0, -2], [2, 3, -1, 0], [3, 1, 2, -3], [3, -4], [2, 2, 2], + [2, 2, 3], [-3, -3, -4], [-3, 2, 1] + ], + "const_axis": [True, False], + "keepdims": [True, False], + }] + + def build_graph(parameters): + """Build the reduce op testing graph.""" + input_tensor = tf.placeholder( + dtype=parameters["input_dtype"], + name="input", + shape=parameters["input_shape"]) + + # Get axis as either a placeholder or constants. + if parameters["const_axis"]: + axis = parameters["axis"] + input_tensors = [input_tensor] else: - shape = [0] # shape for None or integers. - axis = tf.placeholder(dtype=tf.int32, name="axis", shape=shape) - input_tensors = [input_tensor, axis] + if isinstance(parameters["axis"], list): + shape = [len(parameters["axis"])] + else: + shape = [0] # shape for None or integers.
+ axis = tf.placeholder(dtype=tf.int32, name="axis", shape=shape) + input_tensors = [input_tensor, axis] - out = tf.reduce_mean( - input_tensor, axis=axis, keepdims=parameters["keepdims"]) - return input_tensors, [out] + out = reduce_op( + input_tensor, axis=axis, keepdims=parameters["keepdims"]) + return input_tensors, [out] - def build_inputs(parameters, sess, inputs, outputs): - values = [ - create_tensor_data(parameters["input_dtype"], parameters["input_shape"]) - ] - if not parameters["const_axis"]: - if parameters["axis"]: - values.append(np.array(parameters["axis"])) - return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + def build_inputs(parameters, sess, inputs, outputs): + values = [ + create_tensor_data(parameters["input_dtype"], + parameters["input_shape"])] + if not parameters["const_axis"]: + if parameters["axis"]: + values.append(np.array(parameters["axis"])) + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + return f + + +def make_mean_tests(zip_path): + """Make a set of tests to do mean.""" + + return make_reduce_tests(tf.reduce_mean)(zip_path) def make_exp_tests(zip_path): @@ -1034,40 +1103,39 @@ def make_fused_batch_norm_tests(zip_path): def make_conv_tests(zip_path): """Make a set of tests to do convolution.""" - test_parameters = [ - { - "input_shape": [[1, 3, 4, 3]], - "filter_shape": [[1, 1, 3, 2]], - "strides": [[1, 1, 1, 1], [1, 2, 3, 1]], - "dilations": [[1, 1, 1, 1], [1, 3, 2, 1], [1, 2, 2, 1]], - "padding": ["SAME", "VALID"], - "data_format": ["NHWC"], # TODO(aselle): NCHW would be good - "constant_filter": [True, False], - }, - { - "input_shape": [[2, 14, 14, 2]], - "filter_shape": [[6, 6, 2, 2]], - "strides": [[1, 1, 1, 1], [1, 2, 3, 1]], - "dilations": [[1, 1, 1, 1], [1, 2, 2, 1]], - "padding": ["SAME", "VALID"], - "data_format": ["NHWC"], # 
TODO(aselle): NCHW would be good - "constant_filter": [True, False], - } - ] + test_parameters = [{ + "input_shape": [[1, 3, 4, 3], [4, 6, 6, 1]], + "filter_shape": [[1, 1], [2, 3], [3, 3]], + "strides": [[1, 1, 1, 1], [1, 2, 3, 1]], + "dilations": [[1, 1, 1, 1], [1, 3, 2, 1], [1, 2, 2, 1]], + "padding": ["SAME", "VALID"], + "data_format": ["NHWC"], # TODO(aselle): NCHW would be good + "constant_filter": [True, False], + "channel_multiplier": [1, 2], + }] + + def get_tensor_shapes(parameters): + input_shape = parameters["input_shape"] + filter_size = parameters["filter_shape"] + filter_shape = filter_size + [ + input_shape[3], parameters["channel_multiplier"] + ] + return [input_shape, filter_shape] def build_graph(parameters): """Build a conv graph given `parameters`.""" + input_shape, filter_shape = get_tensor_shapes(parameters) input_tensor = tf.placeholder( - dtype=tf.float32, name="input", shape=parameters["input_shape"]) + dtype=tf.float32, name="input", shape=input_shape) # Get filter input either as a placeholder or constants. Also get a list of # the input tensors that are represented as placeholders. if parameters["constant_filter"]: - filter_input = create_tensor_data(np.float32, parameters["filter_shape"]) + filter_input = create_tensor_data(np.float32, filter_shape) input_tensors = [input_tensor] else: filter_input = tf.placeholder( - dtype=tf.float32, name="filter", shape=parameters["filter_shape"]) + dtype=tf.float32, name="filter", shape=filter_shape) input_tensors = [input_tensor, filter_input] out = tf.nn.conv2d( @@ -1082,9 +1150,10 @@ def make_conv_tests(zip_path): def build_inputs(parameters, sess, inputs, outputs): # Build list of input values either containing 1 tensor (input) or 2 tensors # (input, filter) based on whether filter is constant or variable input. 
- values = [create_tensor_data(np.float32, parameters["input_shape"])] + input_shape, filter_shape = get_tensor_shapes(parameters) + values = [create_tensor_data(np.float32, input_shape)] if not parameters["constant_filter"]: - values.append(create_tensor_data(np.float32, parameters["filter_shape"])) + values.append(create_tensor_data(np.float32, filter_shape)) return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -1316,10 +1385,10 @@ def make_local_response_norm_tests(zip_path): # Chose a set of parameters test_parameters = [{ "input_shape": [[1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3]], - "depth_radius": [None, 0, 1, 3, 4, 5], - "bias": [None, 0.1, 0.3, -0.1], - "alpha": [None, 1, 2, -3], - "beta": [None, 0.5, 0.25, 2], + "depth_radius": [None, 0, 1, 3, 5], + "bias": [None, 0.3, -0.1], + "alpha": [None, 2, -3], + "beta": [None, 0.25, 2], }] def build_graph(parameters): @@ -1551,7 +1620,7 @@ def make_space_to_depth_tests(zip_path): """Make a set of tests to do space_to_depth.""" test_parameters = [{ - "dtype": [tf.float32, tf.float16, tf.int32, tf.uint8, tf.int64], + "dtype": [tf.float32, tf.int32, tf.uint8, tf.int64], "input_shape": [[2, 12, 24, 1]], "block_size": [2, 3, 4], }] @@ -1794,65 +1863,8 @@ def make_squeeze_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) -def make_strided_slice_tests(zip_path): - """Make a set of tests to do strided_slice.""" - - # TODO(soroosh): add test/support for uint8. 
- test_parameters = [ - # 4-D - { - "dtype": [tf.float32, tf.int32, tf.int64], - "index_type": [tf.int32], - "input_shape": [[12, 2, 2, 5]], - "begin": [[0, 0, 0, 0], [1, 0, 1, 0]], - "end": [[8, 2, 2, 3], [12, 2, 2, 5]], - "strides": [None, [2, 1, 3, 1]], - "begin_mask": [None, 1, 8], - "end_mask": [None, 1, 8], - "shrink_axis_mask": [None, 1, 8, 11, 15, -1], - "constant_indices": [False, True], - }, - # TODO(b/73170889) Restore test parameters removed in cl/191608113. - # 2-D - { - "dtype": [tf.float32, tf.int32, tf.int64], - "index_type": [tf.int32], - "input_shape": [[2, 3]], - "begin": [[0, 0], [1, 0]], - "end": [[2, 3], [2, 2]], - "strides": [None, [2, 2]], - "begin_mask": [None, 1, 2], - "end_mask": [None, 1, 2], - "shrink_axis_mask": [None, 1, 2, 3, -1], - "constant_indices": [False, True], - }, - # 1-D Exhaustive - { - "dtype": [tf.float32], - "index_type": [tf.int32], - "input_shape": [[4]], - "begin": [[-100], [-3], [-2], [-1], [0], [1], [2], [3], [100]], - "end": [[-100], [-3], [-2], [-1], [0], [1], [2], [3], [100]], - "strides": [-2, -1, 1, 2], - "begin_mask": [0, 1], - "end_mask": [0, 1], - "shrink_axis_mask": [0], - "constant_indices": [False], - }, - # Negative strides - { - "dtype": [tf.float32], - "index_type": [tf.int32], - "input_shape": [[2, 3]], - "begin": [[0, -1]], - "end": [[2, -3]], - "strides": [[1, -1]], - "begin_mask": [None, 1, 2], - "end_mask": [None, 1, 2], - "shrink_axis_mask": [None, 1, 2, 3, -1], - "constant_indices": [False], - }, - ] +def _make_strided_slice_tests(zip_path, test_parameters): + """Utility function to make strided_slice_tests based on parameters.""" def build_graph(parameters): """Build graph for stride_slice test.""" @@ -1914,6 +1926,100 @@ def make_strided_slice_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_strided_slice_tests(zip_path): + """Make a set of tests to do strided_slice.""" + + # TODO(soroosh): add test/support for uint8. 
+ test_parameters = [ + # 4-D (basic cases with const/non-const indices). + { + "dtype": [tf.float32, tf.int32, tf.int64], + "index_type": [tf.int32], + "input_shape": [[12, 2, 2, 5]], + "strides": [None, [2, 1, 3, 1]], + "begin": [[0, 0, 0, 0]], + "end": [[12, 2, 2, 5]], + "begin_mask": [None], + "end_mask": [None], + "shrink_axis_mask": [None], + "constant_indices": [False, True], + }, + # 4-D with non-trivial begin & end. + { + "dtype": [tf.float32], + "index_type": [tf.int32], + "input_shape": [[12, 2, 2, 5]], + "begin": [[0, 0, 0, 0], [1, 0, 1, 0]], + "end": [[8, 2, 2, 3], [12, 2, 2, 5]], + "strides": [None, [2, 1, 3, 1]], + "begin_mask": [None, 8], + "end_mask": [None, 3], + "shrink_axis_mask": [None, 15, -1], + "constant_indices": [True], + }, + # Begin, end, strides dim are different from input shape + { + "dtype": [tf.float32], + "index_type": [tf.int32], + "input_shape": [[12, 2, 2, 5]], + "begin": [[0]], + "end": [[1]], + "strides": [None, [1]], + "begin_mask": [0], + "end_mask": [0], + "shrink_axis_mask": [1], + "constant_indices": [True], + }, + # 2-D + { + "dtype": [tf.float32], + "index_type": [tf.int32], + "input_shape": [[2, 3]], + "begin": [[0, 0]], + "end": [[2, 2]], + "strides": [None, [2, 2]], + "begin_mask": [None, 1, 2], + "end_mask": [None, 1, 2], + "shrink_axis_mask": [None, 1, 2, 3, -1], + "constant_indices": [False, True], + }, + # Negative strides + { + "dtype": [tf.float32], + "index_type": [tf.int32], + "input_shape": [[2, 3]], + "begin": [[0, -1]], + "end": [[2, -3]], + "strides": [[1, -1]], + "begin_mask": [None, 1, 2], + "end_mask": [None, 1, 2], + "shrink_axis_mask": [None, 1, 2, 3, -1], + "constant_indices": [False], + }, + ] + _make_strided_slice_tests(zip_path, test_parameters) + + +def make_strided_slice_1d_exhaustive_tests(zip_path): + """Make a set of exhaustive tests for 1D strided_slice.""" + test_parameters = [ + # 1-D Exhaustive + { + "dtype": [tf.float32], + "index_type": [tf.int32], + "input_shape": [[3]], + "begin": 
[[-2], [-1], [0], [1], [2]], + "end": [[-2], [-1], [0], [1], [2]], + "strides": [[-2], [-1], [1], [2]], + "begin_mask": [0, 1], + "end_mask": [0, 1], + "shrink_axis_mask": [0], + "constant_indices": [False], + }, + ] + _make_strided_slice_tests(zip_path, test_parameters) + + def make_lstm_tests(zip_path): """Make a set of tests to do basic Lstm cell.""" @@ -1924,6 +2030,7 @@ def make_lstm_tests(zip_path): "time_step_size": [1], "input_vec_size": [3], "num_cells": [4], + "split_tflite_lstm_inputs": [True, False], }, ] @@ -2032,8 +2139,8 @@ def make_arg_max_tests(zip_path): test_parameters = [{ "input_dtype": [tf.float32, tf.int32], "input_shape": [[1, 1, 1, 3], [2, 3, 4, 5], [2, 3, 3], [5, 5], [10]], - "axis": [0, 1, 2, 3], "output_type": [tf.int32, tf.int64], + "axis_is_last_dim": [True, False], }] def build_graph(parameters): @@ -2042,7 +2149,10 @@ def make_arg_max_tests(zip_path): dtype=parameters["input_dtype"], name="input", shape=parameters["input_shape"]) - axis = tf.constant(parameters["axis"], name="axis") + if parameters["axis_is_last_dim"]: + axis = len(parameters["input_shape"]) - 1 + else: + axis = random.randint(0, max(len(parameters["input_shape"]) - 2, 0)) out = tf.arg_max(input_value, axis, output_type=parameters["output_type"]) return [input_value], [out] @@ -2055,6 +2165,74 @@ def make_arg_max_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_equal_tests(zip_path): + """Make a set of tests to do equal.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]), + ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), + ([5, 5], [1]), ([10], [2, 4, 10])], + }] + + def build_graph(parameters): + """Build the equal op testing graph.""" + input_value1 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape_pair"][0]) + input_value2 = tf.placeholder( + dtype=parameters["input_dtype"], + 
name="input2", + shape=parameters["input_shape_pair"][1]) + out = tf.equal(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][0]) + input_value2 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_not_equal_tests(zip_path): + """Make a set of tests to do not equal.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]), + ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), + ([5, 5], [1]), ([10], [2, 4, 10])], + }] + + def build_graph(parameters): + """Build the not equal op testing graph.""" + input_value1 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape_pair"][0]) + input_value2 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input2", + shape=parameters["input_shape_pair"][1]) + out = tf.not_equal(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][0]) + input_value2 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_greater_tests(zip_path): """Make a set of tests to do greater.""" @@ -2242,6 +2420,46 @@ def make_neg_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def _make_elementwise_tests(op): + """Make a set of tests to do element-wise operations.""" + + def f(zip_path): + """Actual function that generates examples.""" + test_parameters = [{ + "input_dtype": [tf.float32], + "input_shape": [[1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], + }] + + def build_graph(parameters): + """Build the element-wise op testing graph.""" + input_value = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape"]) + out = op(input_value) + return [input_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict={inputs[0]: input_value}) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + return f + + +def make_sin_tests(zip_path): + """Make a set of tests to do sin.""" + return _make_elementwise_tests(tf.sin)(zip_path) + + +def make_log_tests(zip_path): + """Make a set of tests to do log.""" + return _make_elementwise_tests(tf.log)(zip_path) + + def make_where_tests(zip_path): """Make a set of tests to do where.""" @@ -2274,9 +2492,257 @@ def make_where_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + +def make_slice_tests(zip_path): + """Make a set of tests to do slice.""" + + # TODO(renjieliu): add test/support for uint8.
+ test_parameters = [ + # 4-D + { + "dtype": [tf.float32, tf.int32, tf.int64], + "index_type": [tf.int32, tf.int64], + "input_shape": [[12, 2, 2, 5]], + "begin": [[0, 0, 0, 0], [1, 0, 1, 0]], + "size": [[8, 2, 2, 3], [11, 2, 1, 5]], + }, + # 2-D + { + "dtype": [tf.float32, tf.int32, tf.int64], + "index_type": [tf.int32, tf.int64], + "input_shape": [[2, 3]], + "begin": [[0, 0], [1, 0]], + "size": [[2, 3], [2, 2]], + }, + ] + + def build_graph(parameters): + """Build graph for slice test.""" + input_tensor = tf.placeholder( + dtype=parameters["dtype"], + name="input", + shape=parameters["input_shape"]) + begin = tf.placeholder( + dtype=parameters["index_type"], + name="begin", + shape=[len(parameters["input_shape"])]) + size = tf.placeholder( + dtype=parameters["index_type"], + name="size", + shape=[len(parameters["input_shape"])]) + tensors = [input_tensor, begin, size] + out = tf.slice(input_tensor, begin, size) + return tensors, [out] + + def build_inputs(parameters, sess, inputs, outputs): + """Build inputs for slice test.""" + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + index_type = _TF_TYPE_INFO[parameters["index_type"]][0] + + begin_values = np.array(parameters["begin"]).astype(index_type) + size_values = np.array(parameters["size"]).astype(index_type) + values = [input_values, begin_values, size_values] + + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +# Computing the output_shape for the input_sizes argument of +# tf.nn.conv2d_backprop_input is fairly complicated, so we first perform a +# "conv2d" operation to get the output, then feed that output into +# tf.nn.conv2d_backprop_input. +# This test therefore depends on the correctness of the "conv2d" operation.
+def make_transpose_conv_tests(zip_path): + """Make a set of tests to do transpose_conv.""" + + # Tensorflow only supports equal strides + test_parameters = [{ + "input_shape": [[1, 3, 4, 1], [1, 10, 10, 3], [3, 20, 20, 1]], + "filter_size": [[1, 1], [1, 2], [3, 3]], + "strides": [[1, 1, 1, 1], [1, 3, 3, 1]], + "padding": ["SAME", "VALID"], + "data_format": ["NHWC"], + "channel_multiplier": [1, 2], + }] + + def get_tensor_shapes(parameters): + input_shape = parameters["input_shape"] + filter_size = parameters["filter_size"] + filter_shape = filter_size + [ + input_shape[3], parameters["channel_multiplier"] + ] + return [input_shape, filter_shape] + + def build_graph(parameters): + """Build a transpose_conv graph given `parameters`.""" + input_shape, filter_shape = get_tensor_shapes(parameters) + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=input_shape) + + filter_input = tf.placeholder( + dtype=tf.float32, name="filter", shape=filter_shape) + + conv_outputs = tf.nn.conv2d( + input_tensor, + filter_input, + strides=parameters["strides"], + padding=parameters["padding"], + data_format=parameters["data_format"]) + out = tf.nn.conv2d_backprop_input( + input_shape, + filter_input, + conv_outputs, + strides=parameters["strides"], + padding=parameters["padding"], + data_format=parameters["data_format"]) + input_tensors = [input_tensor, filter_input] + return input_tensors, [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_shape, filter_shape = get_tensor_shapes(parameters) + values = [ + create_tensor_data(np.float32, input_shape), + create_tensor_data(np.float32, filter_shape) + ] + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_tile_tests(zip_path): + """Make a set of tests to do tile.""" + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32], + "input_shape": [[3, 2, 1], [2, 2, 2]], + "multiplier_dtype": 
[tf.int32, tf.int64], + "multiplier_shape": [[3]] + }] + + def build_graph(parameters): + """Build the tile op testing graph.""" + input_value = tf.placeholder( + dtype=parameters["input_dtype"], + shape=parameters["input_shape"], + name="input") + multiplier_value = tf.placeholder( + dtype=parameters["multiplier_dtype"], + shape=parameters["multiplier_shape"], + name="multiplier") + out = tf.tile(input_value, multiplier_value) + return [input_value, multiplier_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + multipliers_value = create_tensor_data(parameters["multiplier_dtype"], + parameters["multiplier_shape"]) + return [input_value, multipliers_value], sess.run( + outputs, + feed_dict={ + inputs[0]: input_value, + inputs[1]: multipliers_value + }) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_expand_dims_tests(zip_path): + """Make a set of tests to do expand_dims.""" + + test_parameters = [{ + "input_type": [tf.float32, tf.int32], + "input_shape": [[3, 4], [10, 10, 3]], + "axis_value": [0, 1, 2, -1, -2], + }] + + def build_graph(parameters): + """Build the expand_dims op testing graph.""" + input_value = tf.placeholder( + dtype=parameters["input_type"], + name="input", + shape=parameters["input_shape"]) + axis_value = tf.placeholder(dtype=tf.int32, name="axis", shape=[1]) + out = tf.expand_dims(input_value, axis=axis_value) + return [input_value, axis_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_type"], + parameters["input_shape"]) + axis_value = np.array([parameters["axis_value"]], dtype=np.int32) + return [input_value, axis_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value, axis_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_sparse_to_dense_tests(zip_path): +
"""Make a set of tests to do sparse to dense.""" + + test_parameters = [{ + "value_dtype": [tf.float32, tf.int32], + "index_dtype": [tf.int32, tf.int64], + "value_count": [1, 3, 6, 8], + "dense_shape": [[15], [3, 10], [4, 4, 4, 4], [7, 10, 9]], + "default_value": [0, -1], + "value_is_scalar": [True, False], + }] + + # Return a single value for 1-D dense shape, but a tuple for other shapes. + def generate_index(dense_shape): + if len(dense_shape) == 1: + return np.random.randint(dense_shape[0]) + else: + index = [] + for shape in dense_shape: + index.append(np.random.randint(shape)) + return tuple(index) + + def build_graph(parameters): + """Build the sparse_to_dense op testing graph.""" + dense_shape = parameters["dense_shape"] + + # Special handle for value_is_scalar case. + # value_count must be 1. + if parameters["value_is_scalar"] and parameters["value_count"] == 1: + value = tf.placeholder( + name="value", dtype=parameters["value_dtype"], shape=()) + else: + value = tf.placeholder( + name="value", + dtype=parameters["value_dtype"], + shape=[parameters["value_count"]]) + indices = set() + while len(indices) < parameters["value_count"]: + indices.add(generate_index(dense_shape)) + indices = tf.constant(tuple(indices), dtype=parameters["index_dtype"]) + # TODO(renjieliu): Add test for validate_indices case. + out = tf.sparse_to_dense( + indices, + dense_shape, + value, + parameters["default_value"], + validate_indices=False) + + return [value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + if parameters["value_is_scalar"] and parameters["value_count"] == 1: + input_value = create_scalar_data(parameters["value_dtype"]) + else: + input_value = create_tensor_data(parameters["value_dtype"], + [parameters["value_count"]]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + # Toco binary path provided by the generate rule. 
bin_path = None + def main(unused_args): global bin_path def mkdir_if_not_exist(x): diff --git a/tensorflow/contrib/lite/testing/generate_testspec.cc b/tensorflow/contrib/lite/testing/generate_testspec.cc index 6580845af42b3cdded19b578b41c682089aaf9ef..c0c861ff6da2fc144b9303dfdd48f19794cebeca 100644 --- a/tensorflow/contrib/lite/testing/generate_testspec.cc +++ b/tensorflow/contrib/lite/testing/generate_testspec.cc @@ -80,11 +80,30 @@ bool GenerateTestSpecFromTensorflowModel( // Invoke tensorflow model. TfDriver runner(input_layer, input_layer_type, input_layer_shape, output_layer); + if (!runner.IsValid()) { + cerr << runner.GetErrorMessage() << endl; + return false; + } + runner.LoadModel(tensorflow_model_path); + if (!runner.IsValid()) { + cerr << runner.GetErrorMessage() << endl; + return false; + } + for (int i = 0; i < input_values.size(); i++) { runner.SetInput(i, input_values[i]); + if (!runner.IsValid()) { + cerr << runner.GetErrorMessage() << endl; + return false; + } } + runner.Invoke(); + if (!runner.IsValid()) { + cerr << runner.GetErrorMessage() << endl; + return false; + } // Write test spec. 
stream << "load_model: " << tflite_model_path << "\n"; @@ -99,6 +118,10 @@ bool GenerateTestSpecFromTensorflowModel( } for (int i = 0; i < output_layer.size(); i++) { stream << " output: \"" << runner.ReadOutput(i) << "\"\n"; + if (!runner.IsValid()) { + cerr << runner.GetErrorMessage() << endl; + return false; + } } stream << "}\n"; diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 49762bdfe7139cff5c40b8609dd11435f2548175..e85020448a572650c6a70d8b4dcb4e73faf0f8c8 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -35,7 +35,7 @@ namespace { bool FLAGS_ignore_known_bugs = true; // TODO(b/71769302) zip_files_dir should have a more accurate default, if // possible -string* FLAGS_zip_files_dir = new string("./"); +string* FLAGS_zip_file_path = new string("./"); string* FLAGS_unzip_binary_path = new string("/usr/bin/unzip"); } // namespace @@ -48,7 +48,7 @@ tensorflow::Env* env = tensorflow::Env::Default(); // TODO(ahentz): make sure we clean this list up frequently. std::map kBrokenTests = { // Add only supports float32. (and "constant" tests use Add) - {R"(^\/adda.*int32)", "68808744"}, + {R"(^\/add_a.*int32)", "68808744"}, {R"(^\/constant.*int32)", "68808744"}, {R"(^\/mul.*int32)", "68808744"}, {R"(^\/div.*int32)", "68808744"}, @@ -61,29 +61,25 @@ std::map kBrokenTests = { "70527055"}, // L2Norm only supports tensors with 4D or fewer. - {R"(^\/l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, - - // BatchToSpaceND doesn't support cropping. This catches test cases with - // non-const tensors as crops. - {R"(^\/batch_to_space_nd.*crops=\[\[1,1\],\[1,1\]\])", "70594634"}, + {R"(^\/l2norm_dim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, // SpaceToBatchND only supports 4D tensors. 
{R"(^\/space_to_batch_nd.*input_shape=\[1,4,4,4,1,1\])", "70848787"}, // L2Norm only works for dim=-1. - {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[.,.\])", "67963812"}, - {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[.,.\])", "67963812"}, - {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(^\/l2normdim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2normdim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(^\/l2normdim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2normdim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(^\/l2normdim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])", + {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[.,.\])", "67963812"}, + {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[.,.\])", "67963812"}, + {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm_dim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2norm_dim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm_dim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2norm_dim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm_dim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2normdim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm_dim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, // ResizeBilinear looks completely incompatible with Tensorflow {R"(^\/resize_bilinear.*dtype=tf.int32)", "72401107"}, @@ -98,9 +94,11 
@@ std::map kBrokenTests = { {R"(^\/gather.*axis=1)", "76910444"}, // No support for arbitrary dimensions in ArgMax. - {R"(^\/arg_max.*axis=0)", "77546240"}, - {R"(^\/arg_max.*axis=1)", "77546240"}, - {R"(^\/arg_max.*axis=2)", "77546240"}, + {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.,.,.\])", + "77546240"}, + {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.,.\])", + "77546240"}, + {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.\])", "77546240"}, }; // Allows test data to be unzipped into a temporary directory and makes @@ -139,7 +137,10 @@ class ZipEnvironment : public ::testing::Environment { *out_dir = dir; return tensorflow::Status::OK(); } else { - return tensorflow::Status(tensorflow::error::UNKNOWN, "unzip failed"); + return tensorflow::Status(tensorflow::error::UNKNOWN, + "unzip failed. " + "stdout:\n" + + out + "\nstderr:\n" + err); } } @@ -193,8 +194,7 @@ tensorflow::Status ReadManifest(const string& original_file, const string& dir, } // Get a list of tests from a zip file `zip_file_name`. 
-std::vector UnarchiveZipAndFindTestNames(const string& zip_file_name) { - string zip_file = *FLAGS_zip_files_dir + "/" + zip_file_name; +std::vector UnarchiveZipAndFindTestNames(const string& zip_file) { string decompress_tmp_dir; TF_CHECK_OK(zip_environment()->UnZip(zip_file, &decompress_tmp_dir)); std::vector stuff; @@ -204,7 +204,7 @@ std::vector UnarchiveZipAndFindTestNames(const string& zip_file_name) { class OpsTest : public ::testing::TestWithParam {}; -TEST_P(OpsTest, RunStuff) { +TEST_P(OpsTest, RunZipTests) { string test_path = GetParam(); string tflite_test_case = test_path + "_tests.txt"; string tflite_dir = test_path.substr(0, test_path.find_last_of("/")); @@ -227,7 +227,9 @@ TEST_P(OpsTest, RunStuff) { EXPECT_TRUE(result) << test_driver.GetErrorMessage(); } else { if (FLAGS_ignore_known_bugs) { - EXPECT_FALSE(result); + EXPECT_FALSE(result) << "Test was expected to fail but is now passing; " + "you can mark http://b/" + << bug_number << " as fixed! Yay!"; } else { EXPECT_TRUE(result) << test_driver.GetErrorMessage() << ": Possibly due to http://b/" << bug_number; @@ -235,61 +237,26 @@ TEST_P(OpsTest, RunStuff) { } } -// Instantiate a test. This assumes `zip_base`.zip is a declared data file -// of this test. 
-#define INSTANTIATE_TESTS(zip_base) \ - INSTANTIATE_TEST_CASE_P( \ - zip_base, OpsTest, \ - ::testing::ValuesIn(UnarchiveZipAndFindTestNames(#zip_base ".zip"))); - -INSTANTIATE_TESTS(add) -INSTANTIATE_TESTS(arg_max) -INSTANTIATE_TESTS(avg_pool) -INSTANTIATE_TESTS(batch_to_space_nd) -INSTANTIATE_TESTS(concat) -INSTANTIATE_TESTS(constant) -INSTANTIATE_TESTS(control_dep) -INSTANTIATE_TESTS(conv) -INSTANTIATE_TESTS(depthwiseconv) -INSTANTIATE_TESTS(div) -INSTANTIATE_TESTS(exp) -INSTANTIATE_TESTS(floor) -INSTANTIATE_TESTS(fully_connected) -INSTANTIATE_TESTS(fused_batch_norm) -INSTANTIATE_TESTS(gather) -INSTANTIATE_TESTS(global_batch_norm) -INSTANTIATE_TESTS(greater) -INSTANTIATE_TESTS(greater_equal) -INSTANTIATE_TESTS(l2_pool) -INSTANTIATE_TESTS(l2norm) -INSTANTIATE_TESTS(less) -INSTANTIATE_TESTS(less_equal) -INSTANTIATE_TESTS(local_response_norm) -INSTANTIATE_TESTS(log_softmax) -INSTANTIATE_TESTS(max_pool) -INSTANTIATE_TESTS(maximum) -INSTANTIATE_TESTS(mean) -INSTANTIATE_TESTS(minimum) -INSTANTIATE_TESTS(mul) -INSTANTIATE_TESTS(neg) -INSTANTIATE_TESTS(pad) -INSTANTIATE_TESTS(padv2) -// INSTANTIATE_TESTS(prelu) -INSTANTIATE_TESTS(relu) -INSTANTIATE_TESTS(relu1) -INSTANTIATE_TESTS(relu6) -INSTANTIATE_TESTS(reshape) -INSTANTIATE_TESTS(resize_bilinear) -INSTANTIATE_TESTS(sigmoid) -INSTANTIATE_TESTS(softmax) -INSTANTIATE_TESTS(space_to_batch_nd) -INSTANTIATE_TESTS(space_to_depth) -INSTANTIATE_TESTS(split) -INSTANTIATE_TESTS(squeeze) -INSTANTIATE_TESTS(strided_slice) -INSTANTIATE_TESTS(sub) -INSTANTIATE_TESTS(transpose) -INSTANTIATE_TESTS(where) +struct ZipPathParamName { + template + string operator()(const ::testing::TestParamInfo& info) const { + string param_name = info.param; + size_t last_slash = param_name.find_last_of("\\/"); + if (last_slash != string::npos) { + param_name = param_name.substr(last_slash); + } + for (size_t index = 0; index < param_name.size(); ++index) { + if (!isalnum(param_name[index]) && param_name[index] != '_') + param_name[index] = '_'; + } + 
return param_name; + } +}; + +INSTANTIATE_TEST_CASE_P( + tests, OpsTest, + ::testing::ValuesIn(UnarchiveZipAndFindTestNames(*FLAGS_zip_file_path)), + ZipPathParamName()); } // namespace testing } // namespace tflite @@ -302,8 +269,8 @@ int main(int argc, char** argv) { "ignore_known_bugs", &tflite::testing::FLAGS_ignore_known_bugs, "If a particular model is affected by a known bug, the " "corresponding test should expect the outputs to not match."), - tensorflow::Flag("zip_files_dir", tflite::testing::FLAGS_zip_files_dir, - "Required: Location of the test zips."), + tensorflow::Flag("zip_file_path", tflite::testing::FLAGS_zip_file_path, + "Required: Location of the test zip file."), tensorflow::Flag("unzip_binary_path", tflite::testing::FLAGS_unzip_binary_path, "Required: Location of a suitable unzip binary.")}; diff --git a/tensorflow/contrib/lite/testing/join.h b/tensorflow/contrib/lite/testing/join.h index ce8c072a21c6e61e8ab8ae12ba52418e6144009a..1edee01cf97da3c53be1895e667b005551ac2991 100644 --- a/tensorflow/contrib/lite/testing/join.h +++ b/tensorflow/contrib/lite/testing/join.h @@ -22,7 +22,7 @@ limitations under the License. namespace tflite { namespace testing { -// Join a list of data separated by delimieter. +// Join a list of data separated by delimiter. template string Join(T* data, size_t len, const string& delimiter) { if (len == 0 || data == nullptr) { @@ -36,6 +36,22 @@ string Join(T* data, size_t len, const string& delimiter) { return result.str(); } +// Join a list of uint8 data separated by a delimiter. Cast data to int before +// placing it in the string to prevent values from being treated like chars. 
+template <> +inline string Join(uint8_t* data, size_t len, + const string& delimiter) { + if (len == 0 || data == nullptr) { + return ""; + } + std::stringstream result; + result << static_cast<int>(data[0]); + for (int i = 1; i < len; i++) { + result << delimiter << static_cast<int>(data[i]); + } + return result.str(); +} + } // namespace testing } // namespace tflite diff --git a/tensorflow/contrib/lite/testing/test_runner.h b/tensorflow/contrib/lite/testing/test_runner.h index 05770beee23275ebe210606dbfd2b33eea17612d..96ab6be54e528334f9e4a8cc259e44f99878fefb 100644 --- a/tensorflow/contrib/lite/testing/test_runner.h +++ b/tensorflow/contrib/lite/testing/test_runner.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ #define TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ +#include #include #include #include @@ -89,6 +90,7 @@ class TestRunner { // Invalidate the test runner, preventing it from executing any further. void Invalidate(const string& error_message) { + cerr << error_message << std::endl; error_message_ = error_message; } bool IsValid() const { return error_message_.empty(); } diff --git a/tensorflow/contrib/lite/testing/tf_driver.cc b/tensorflow/contrib/lite/testing/tf_driver.cc index 7b295875aab12bf48da2341ce05dd53442464cf0..3b27f6f3da92ce80c3830feb7c6af095e7c48e9c 100644 --- a/tensorflow/contrib/lite/testing/tf_driver.cc +++ b/tensorflow/contrib/lite/testing/tf_driver.cc @@ -103,7 +103,7 @@ void TfDriver::LoadModel(const string& bin_file_path) { session_.reset(tensorflow::NewSession(options)); auto status = session_->Create(graphdef); if (!status.ok()) { - Invalidate("Failed to create session"); + Invalidate("Failed to create session. 
" + status.error_message()); } } diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc index 75ac24719aa8fad960ae06d006eda386d44d721a..fc28faf52405b300dc6e4f0aab33122bb5e98f12 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.cc +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/testing/split.h" namespace tflite { @@ -143,6 +144,7 @@ void TfLiteDriver::AllocateTensors() { Invalidate("Failed to allocate tensors"); return; } + ResetLSTMStateTensors(); must_allocate_tensors_ = false; } } @@ -281,5 +283,36 @@ bool TfLiteDriver::CheckResults() { return success; } +void TfLiteDriver::ResetLSTMStateTensors() { + // This is a workaround for initializing state tensors for LSTM. + // TODO(ycling): Refactor and find a better way to initialize state + // tensors. Maybe write the reset instructions into the test data. + for (auto node_index : interpreter_->execution_plan()) { + const auto& node_and_reg = interpreter_->node_and_registration(node_index); + const auto& node = node_and_reg->first; + const auto& registration = node_and_reg->second; + + if (registration.builtin_code == tflite::BuiltinOperator_LSTM) { + const auto* params = + reinterpret_cast(node.builtin_data); + if (params->kernel_type == kTfLiteLSTMFullKernel && + node.outputs->size >= 2) { + // The first 2 outputs of LSTM are state tensors. + for (int i = 0; i < 2; ++i) { + int node_index = node.outputs->data[i]; + ResetTensor(node_index); + } + } else if (params->kernel_type == kTfLiteLSTMBasicKernel && + node.inputs->size == 5) { + // The 2nd and 5th inputs are state tensors. 
+ for (int i : {1, 4}) { + int node_index = node.inputs->data[i]; + ResetTensor(node_index); + } + } + } + } +} + } // namespace testing } // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h index 02b7de1534e648734d7bc53154afa42f2ef256b4..5493ba3631b0423942cc9c4f98fbd6393a404060 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.h +++ b/tensorflow/contrib/lite/testing/tflite_driver.h @@ -48,6 +48,8 @@ class TfLiteDriver : public TestRunner { string ReadOutput(int id) override { return "no-op"; } private: + void ResetLSTMStateTensors(); + class Expectation; bool use_nnapi_ = false; diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index 01ce0d9db21222a01a9a35363fa82347bf5a690d..7ea4f32ef694f3b0dc9c030b9440268ac79848aa 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -245,6 +245,7 @@ cc_library( "graph_transformations/quantization_util.cc", "graph_transformations/quantization_util.h", "graph_transformations/quantize.cc", + "graph_transformations/quantize_weights.cc", "graph_transformations/read_fake_quant_min_max.cc", "graph_transformations/remove_final_dequantize_op.cc", "graph_transformations/remove_tensorflow_assert.cc", @@ -273,6 +274,7 @@ cc_library( "graph_transformations/resolve_constant_range.cc", "graph_transformations/resolve_constant_reshape.cc", "graph_transformations/resolve_constant_shape_or_rank.cc", + "graph_transformations/resolve_constant_slice.cc", "graph_transformations/resolve_constant_stack.cc", "graph_transformations/resolve_constant_strided_slice.cc", "graph_transformations/resolve_constant_transpose.cc", diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index 6c0311af0a926711955caaa1c7507d7c52c77069..9f5ca66d050f0ead9b8856c77dba8d9bbd182d10 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -234,6 
+234,7 @@ struct ParsedTocoFlags { Arg drop_fake_quant = Arg(false); Arg reorder_across_fake_quant = Arg(false); Arg allow_custom_ops = Arg(false); + Arg quantize_weights = Arg(false); // Deprecated flags Arg input_type; Arg input_types; @@ -242,6 +243,7 @@ struct ParsedTocoFlags { Arg propagate_fake_quant_num_bits = Arg(false); Arg allow_nudging_weights_to_use_fast_gemm_kernel = Arg(false); Arg dedupe_array_min_size_bytes = Arg(64); + Arg split_tflite_lstm_inputs = Arg(true); }; } // namespace toco diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc index 166ead918471ee1b06d9683b8dc7baf7bcbdc427..8913b5c3ea962725ef2bed73e670e8f0b988a591 100644 --- a/tensorflow/contrib/lite/toco/dump_graphviz.cc +++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc @@ -16,8 +16,6 @@ limitations under the License. #include #include -#include -#include #include #include "absl/strings/str_replace.h" @@ -91,10 +89,7 @@ Color GetColorForArray(const Model& model, const string& array_name) { // We use gray colors for them because they are the majority // of arrays so we want to highlight other arrays instead of them. 
// First, we use a bolder gray for input/output arrays: - const auto& dump_options = *GraphVizDumpOptions::singleton(); - if (IsInputArray(model, array_name) || - array_name == dump_options.graphviz_first_array || - array_name == dump_options.graphviz_last_array) { + if (IsInputArray(model, array_name)) { return Color(0x9E, 0x9E, 0x9E); } if (IsOutputArray(model, array_name)) { @@ -137,6 +132,12 @@ void AppendArrayVal(string* string, Array const& array, int index) { return; } AppendF(string, "%d", data[index]); + } else if (array.buffer->type == ArrayDataType::kBool) { + const auto& data = array.GetBuffer().data; + if (index >= data.size()) { + return; + } + AppendF(string, "%d", data[index]); } } @@ -287,47 +288,6 @@ NodeProperties GetPropertiesForOperator(const Operator& op) { return node_properties; } -std::vector OperatorsToDump(const Model& model) { - const auto& dump_options = *GraphVizDumpOptions::singleton(); - bool first_specified = !dump_options.graphviz_first_array.empty(); - bool last_specified = !dump_options.graphviz_last_array.empty(); - CHECK_EQ(first_specified, last_specified); - std::vector ops_to_dump; - if (last_specified) { - // Return only the part of the graph between graphviz_first_array - // and graphviz_last_array. 
- CHECK(model.HasArray(dump_options.graphviz_first_array)); - CHECK(model.HasArray(dump_options.graphviz_last_array)); - std::unordered_set arrays_already_produced; - std::vector arrays_to_produce; - arrays_to_produce.push_back(dump_options.graphviz_last_array); - while (!arrays_to_produce.empty()) { - const string array = arrays_to_produce.back(); - arrays_to_produce.pop_back(); - CHECK(!arrays_already_produced.count(array)); - arrays_already_produced.insert(array); - const Operator* op = GetOpWithOutput(model, array); - if (!op) { - continue; - } - ops_to_dump.push_back(op); - for (const string& input : op->inputs) { - if (arrays_already_produced.count(input) || - input == dump_options.graphviz_first_array) { - continue; - } - arrays_to_produce.push_back(input); - } - } - } else { - // Return the whole graph. - for (const auto& op : model.operators) { - ops_to_dump.push_back(op.get()); - } - } - return ops_to_dump; -} - } // namespace void DumpGraphviz(const Model& model, string* output_file_contents) { @@ -348,30 +308,30 @@ void DumpGraphviz(const Model& model, string* output_file_contents) { constexpr char kRNNBackEdgeFormat[] = "\t \"%s\" -> \"%s\" [color=\"#0F9D58\"];\n"; - std::vector ops_to_dump = OperatorsToDump(model); - std::set already_added_arrays; - for (int op_index = 0; op_index < ops_to_dump.size(); op_index++) { - const Operator& op = *ops_to_dump[op_index]; + for (const auto& array_kv : model.GetArrayMap()) { + // Add node for array. + const string& array_name = array_kv.first; + const auto& array_properties = GetPropertiesForArray(model, array_name); + AppendF(output_file_contents, kNodeFormat, array_name, + array_properties.label, "octagon", + array_properties.color.FillColorString().c_str(), + array_properties.color.TextColorString().c_str()); + } + for (int op_index = 0; op_index < model.operators.size(); op_index++) { + const Operator& op = *model.operators[op_index]; // Add node for operator. 
auto op_properties = GetPropertiesForOperator(op); string operator_id = StringF("op%05d", op_index); AppendF(output_file_contents, kNodeFormat, operator_id, op_properties.label, "box", op_properties.color.FillColorString().c_str(), op_properties.color.TextColorString().c_str()); - // Add nodes and edges for all inputs of the operator. + // Add edges for all inputs of the operator. for (const auto& input : op.inputs) { if (!model.HasArray(input)) { // Arrays should _always_ exist. Except, perhaps, during development. continue; } auto array_properties = GetPropertiesForArray(model, input); - if (!already_added_arrays.count(input)) { - AppendF(output_file_contents, kNodeFormat, input, - array_properties.label, "octagon", - array_properties.color.FillColorString().c_str(), - array_properties.color.TextColorString().c_str()); - } - // Draw lines that transport more data thicker (Otherwise, where would the // data fit? right?). float line_width = @@ -387,22 +347,14 @@ void DumpGraphviz(const Model& model, string* output_file_contents) { } AppendF(output_file_contents, kEdgeFormat, input, operator_id, line_width, weight); - already_added_arrays.insert(input); } - // Add nodes and edges for all outputs of the operator. + // Add edges for all outputs of the operator. for (const auto& output : op.outputs) { if (!model.HasArray(output)) { // Arrays should _always_ exist. Except, perhaps, during development. continue; } auto array_properties = GetPropertiesForArray(model, output); - if (!already_added_arrays.count(output)) { - AppendF(output_file_contents, kNodeFormat, output, - array_properties.label, "octagon", - array_properties.color.FillColorString().c_str(), - array_properties.color.TextColorString().c_str()); - } - // See comments above regarding weight and line_width calculations. 
float line_width = std::max(0.5f, array_properties.log2_buffer_size / 3.0f); @@ -412,7 +364,6 @@ void DumpGraphviz(const Model& model, string* output_file_contents) { } AppendF(output_file_contents, kEdgeFormat, operator_id, output, line_width, weight); - already_added_arrays.insert(output); } } diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index f5157149afca17383a8625c489f15a23ce6dd224..c7c80ab21cc88ecb94e5234750431130dbd8400c 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -494,7 +494,7 @@ void ConvertTransposeConvOperator(const Model& model, const auto& weights_array = model.GetArray(weights_array_name); CHECK(weights_array.buffer->type == ArrayDataType::kFloat); ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI, - AxesOrder::kHWIO, tensorflow_graph); + AxesOrder::kHWOI, tensorflow_graph); auto& strides = (*conv2d_op->mutable_attr())["strides"]; strides.mutable_list()->add_i(1); strides.mutable_list()->add_i(src_op.stride_height); @@ -1728,6 +1728,25 @@ void ConvertComparisonOperator(const Model& model, const Operator& src_op, (*comparison_op->mutable_attr())["T"].set_type(data_type); } +void ConvertSparseToDenseOperator(const Model& model, + const SparseToDenseOperator& src_op, + const char* op_name, + GraphDef* tensorflow_graph) { + auto* sparse_to_dense_op = tensorflow_graph->add_node(); + sparse_to_dense_op->set_op(op_name); + sparse_to_dense_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 4); + for (int i = 0; i < 4; ++i) { + *sparse_to_dense_op->add_input() = src_op.inputs[i]; + } + const auto data_type = GetTensorFlowDataType(model, src_op.inputs[3]); + (*sparse_to_dense_op->mutable_attr())["T"].set_type(data_type); + const auto index_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*sparse_to_dense_op->mutable_attr())["Tindices"].set_type(index_type); + 
(*sparse_to_dense_op->mutable_attr())["validate_indices"].set_b( + src_op.validate_indices); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -1919,6 +1938,10 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertRandomUniformOperator( model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowEqual) { + ConvertComparisonOperator(model, src_op, "Equal", tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowNotEqual) { + ConvertComparisonOperator(model, src_op, "NotEqual", tensorflow_graph); } else if (src_op.type == OperatorType::kTensorFlowGreater) { ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph); } else if (src_op.type == OperatorType::kTensorFlowGreaterEqual) { diff --git a/tensorflow/contrib/lite/toco/format_port.h b/tensorflow/contrib/lite/toco/format_port.h index eb81e90faf20133ed722185928f86ef45ac4f8f6..44e668457152376fd8b2e2fa063301468090c3f0 100644 --- a/tensorflow/contrib/lite/toco/format_port.h +++ b/tensorflow/contrib/lite/toco/format_port.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file is used to provide equivalents of internal util::format::FormatF -// and util::format::AppendF. Unfortunately, type safety is not as good as a +// This file is used to provide equivalents of internal absl::FormatF +// and absl::StrAppendFormat. Unfortunately, type safety is not as good as a // a full C++ example. // TODO(aselle): When absl adds support for StrFormat, use that instead. 
#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_ diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md index 7680cdd344814bf6cbc7bbe11c915f220642d55d..8e93f02ef109f7bccd07ce54baff3d0bb4ae50c7 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md @@ -26,8 +26,6 @@ Table of contents: * [Convert a TensorFlow Lite FlatBuffer back into TensorFlow GraphDef format](#to-graphdef) * [Logging](#logging) - * [Standard logging](#standard-logging) - * [Verbose logging](#verbose-logging) * [Graph "video" logging](#graph-video-logging) * [Graph visualizations](#graph-visualizations) * [Using --output_format=GRAPHVIZ_DOT](#using-output-formatgraphviz-dot) @@ -277,49 +275,6 @@ bazel run --config=opt \ ## Logging -### Standard logging - -The converter generates some informative log messages during processing. The -easiest way to view them is to add `--logtostderr` to command lines as seen in -the following example. 
- -``` -curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ - | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ - --output_file=/tmp/foo.tflite \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TFLITE \ - --inference_type=FLOAT \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 \ - --logtostderr -``` - -After some initialization messages, we get the following informative messages: - -``` -I1101 21:51:33.297475 5339 graph_transformations.cc:39] Before general graph transformations: 416 operators, 583 arrays (0 quantized) -I1101 21:51:33.308972 5339 graph_transformations.cc:39] After general graph transformations pass 1: 31 operators, 89 arrays (0 quantized) -I1101 21:51:33.309204 5339 graph_transformations.cc:39] Before dequantization graph transformations: 31 operators, 89 arrays (0 quantized) -I1101 21:51:33.309368 5339 allocate_transient_arrays.cc:312] Total transient array allocated size: 1048576 bytes, theoretical optimal value: 786432 bytes. -I1101 21:51:33.309484 5339 toco_tooling.cc:249] Estimated count of arithmetic ops: 0.099218 billion (note that a multiply-add is counted as 2 ops). -``` - -### Verbose logging - -For debugging purposes, the converter supports two levels of verbose logging, -which can be set by passing a `--v=` flag: - -* For `--v=1`, the converter generates text dumps of the graph at various - points during processing as well as log messages about every graph - transformation that took place. -* For `--v=2`, the converter additionally generates log messages about graph - transformations that were considered but not performed. 
- ### Graph "video" logging When `--dump_graphviz=` is used (see the section on [graph diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md index 9e99287f828c22aa81eb216c087f3261e378fc14..8085ae07489816c38677ff792e7ac71f1a75fa71 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md @@ -203,17 +203,11 @@ have. graph transformations on them, at the cost of no longer faithfully matching inference and training arithmetic. -## Logging flags - -The following are standard Google logging flags: +* `--quantize_weights`. Type: boolean. Default: false. Store weights as + quantized weights followed by dequantize operations. Computation is still + done in float, but reduces model size (at the cost of accuracy and latency). -* `--logtostderr` redirects Google logging to standard error, typically making - it visible in a terminal. -* `--v` sets verbose logging levels (for debugging purposes). Defined levels: - * `--v=1`: log all graph transformations that did make a change on the - graph. - * `--v=2`: log all graph transformations that did *not* make a change on - the graph. +## Logging flags The following flags allow to generate graph visualizations of the actual graph at various points during transformations: diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md index f0fd638a618c75c75d336a746f9b1d8dccaea470..a7841a685528fb18bb08f1943278339a2daec16a 100644 --- a/tensorflow/contrib/lite/toco/g3doc/python_api.md +++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md @@ -1,69 +1,202 @@ -# TensorFlow Lite Optimizing Converter (TOCO) Python API reference +# TensorFlow Lite Optimizing Converter & Interpreter Python API reference -This page provides examples on how to use TOCO via the Python API. 
It is -complemented by the following documents: +This page provides examples on how to use TOCO and the TensorFlow Lite +interpreter via the Python API. It is complemented by the following documents: * [README](../README.md) * [Command-line examples](cmdline_examples.md) * [Command-line glossary](cmdline_reference.md) +Table of contents: + +* [High-level overview](#high-level-overview) +* [API](#api) +* [Basic examples](#basic) + * [Exporting a GraphDef from tf.Session](#basic-graphdef-sess) + * [Exporting a GraphDef from file](#basic-graphdef-file) + * [Exporting a SavedModel](#basic-savedmodel) +* [Complex examples](#complex) + * [Exporting a quantized GraphDef](#complex-quant) +* [TensorFlow Lite Python interpreter](#interpreter) + * [Using the interpreter from a model file](#interpreter-file) + * [Using the interpreter from model data](#interpreter-data) + ## High-level overview While the TensorFlow Lite Optimizing Converter can be used from the command -line, it is often convenient to use it as part of Python model build and +line, it is often convenient to use it as part of a Python model build and training script. This is so that conversion can be part of your model development pipeline. This allows you to know early and often that you are designing a model that can be targeted to devices with mobile. ## API -In Python you can run `help(tf.contrib.lite)` to get documentation on functions. -In particular, `tf.contrib.lite.toco_convert` presents a simple API and -`tf.contrib.lite.toco_from_protos` allows more detailed control of TOCO using -the protobuf interface to TOCO. +The API for converting TensorFlow models to TensorFlow Lite is +`tf.contrib.lite.TocoConverter`. The API for calling the Python intepreter is +`tf.contrib.lite.Interpreter`. + +`TocoConverter` provides class methods based on the original format of the +model. `TocoConverter.from_session()` is available for GraphDefs. +`TocoConverter.from_saved_model()` is available for SavedModels. 
Example usages +for simple floating-point models are shown in [Basic Examples](#basic). Example +usages for more complex models are shown in [Complex Examples](#complex). + +**NOTE**: Currently, `TocoConverter` will cause a fatal error to the Python +interpreter when the conversion fails. This will be remedied as soon as +possible. + +## Basic examples + +The following section shows examples of how to convert a basic floating-point model +from each of the supported data formats into a TensorFlow Lite FlatBuffer. -## Example +### Exporting a GraphDef from tf.Session -In particular, here we show creating a simple model and converting it to a -TensorFlow Lite Model. +The following example shows how to convert a TensorFlow GraphDef into a +TensorFlow Lite FlatBuffer from a `tf.Session` object. ```python import tensorflow as tf img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) -val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) +var = tf.get_variable("weights", dtype=tf.float32, shape=(1, 64, 64, 3)) +val = img + var out = tf.identity(val, name="out") + with tf.Session() as sess: - tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out]) - open("test.tflite", "wb").write(tflite_model) + converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out]) + tflite_model = converter.convert() + open("converted_model.tflite", "wb").write(tflite_model) ``` -**NOTE** Currently, the TOCO command will cause a fatal error to the Python -interpreter when TOCO conversion fails. This will be remedied as soon as -possible. +### Exporting a GraphDef from file + +The following example shows how to convert a TensorFlow GraphDef into a +TensorFlow Lite FlatBuffer when the GraphDef is stored in a file. Both `.pb` and +`.pbtxt` files are accepted. + +The example uses +[Mobilenet_1.0_224](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz).
+The function only supports GraphDefs frozen via +[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py). + +```python +import tensorflow as tf + +graph_def_file = "/path/to/Downloads/mobilenet_v1_1.0_224/frozen_graph.pb" +input_arrays = ["input"] +output_arrays = ["MobilenetV1/Predictions/Softmax"] + +converter = tf.contrib.lite.TocoConverter.from_frozen_graph( + graph_def_file, input_arrays, output_arrays) +tflite_model = converter.convert() +open("converted_model.tflite", "wb").write(tflite_model) +``` + +### Exporting a SavedModel + +The following example shows how to convert a SavedModel into a TensorFlow Lite +FlatBuffer. + +```python +import tensorflow as tf + +converter = tf.contrib.lite.TocoConverter.from_saved_model(saved_model_dir) +tflite_model = converter.convert() +open("converted_model.tflite", "wb").write(tflite_model) +``` + +For more complex SavedModels, the optional parameters that can be passed into +`TocoConverter.from_saved_model()` are `input_arrays`, `input_shapes`, +`output_arrays`, `tag_set` and `signature_key`. Details of each parameter are +available by running `help(tf.contrib.lite.TocoConverter)`. -## Example 2: Export with variables +## Complex examples -If a model has variables, they need to be turned into constants. This process is -known as freezing, and it can actually be accomplished with +For models where the default value of the attributes is not sufficient, the +attribute values should be set before calling `convert()`. In order to call +any constants use `tf.contrib.lite.constants.<CONSTANT_NAME>` as seen below with +`QUANTIZED_UINT8`. Run `help(tf.contrib.lite.TocoConverter)` in the Python +terminal for detailed documentation on the attributes. + +Although the examples are demonstrated on GraphDefs containing only constants, +the same logic can be applied irrespective of the input data format.
+ +### Exporting a quantized GraphDef + +The following example shows how to convert a quantized model into a TensorFlow +Lite FlatBuffer. ```python import tensorflow as tf img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) -var = tf.get_variable("weights", dtype=tf.float32, shape=(1,64,64,3)) -val = img + var +const = tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) +val = img + const +out = tf.fake_quant_with_min_max_args(val, min=0., max=1., name="output") + +with tf.Session() as sess: + converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out]) + converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8 + input_arrays = converter.get_input_arrays() + converter.quantized_input_stats = {input_arrays[0] : (0., 1.)} # mean, std_dev + tflite_model = converter.convert() + open("converted_model.tflite", "wb").write(tflite_model) +``` + +## TensorFlow Lite Python interpreter + +### Using the interpreter from a model file + +The following example shows how to use the TensorFlow Lite Python interpreter +when provided a TensorFlow Lite FlatBuffer file. The example also demonstrates +how to run inference on random input data. Run +`help(tf.contrib.lite.Interpreter)` in the Python terminal to get detailed +documentation on the interpreter. + +```python +import numpy as np +import tensorflow as tf -def canonical_name(x): - return x.name.split(":")[0] +# Load TFLite model and allocate tensors. +interpreter = tf.contrib.lite.Interpreter(model_path="converted_model.tflite") +interpreter.allocate_tensors() +# Get input and output tensors. +input_details = interpreter.get_input_details() +output_details = interpreter.get_output_details() + +# Test model on random input data. 
+input_shape = input_details[0]['shape'] +input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32) +interpreter.set_tensor(input_details[0]['index'], input_data) + +interpreter.invoke() +output_data = interpreter.get_tensor(output_details[0]['index']) +print(output_data) +``` + +### Using the interpreter from model data + +The following example shows how to use the TensorFlow Lite Python interpreter +when starting with the TensorFlow Lite Flatbuffer model previously loaded. This +example shows an end-to-end use case, starting from building the TensorFlow +model. + +```python +import numpy as np +import tensorflow as tf + +img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) +const = tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) +val = img + const out = tf.identity(val, name="out") + with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - out_tensors = [out] - frozen_graphdef = tf.graph_util.convert_variables_to_constants( - sess, sess.graph_def, map(canonical_name, out_tensors)) - tflite_model = tf.contrib.lite.toco_convert( - frozen_graphdef, [img], out_tensors) - open("converted_model.tflite", "wb").write(tflite_model) + converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out]) + tflite_model = converter.convert() + +# Load TFLite model and allocate tensors. 
+interpreter = tf.contrib.lite.Interpreter(model_content=tflite_model) +interpreter.allocate_tensors() ``` diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc index 076415ece8c1039caa32e947fe54ab3e101bec9e..8ca2cd66ac6377a70a4c504ac006dc0388b88bf7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc @@ -46,8 +46,9 @@ bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { const int kheight = weights_shape.dims(1); const int kwidth = weights_shape.dims(2); if (kwidth == 1 && kheight == 1 && conv_op->stride_width == 1 && - conv_op->stride_height == 1) { - // 1x1 unstrided conv does not need an im2col array. + conv_op->stride_height == 1 && conv_op->dilation_width_factor == 1 && + conv_op->dilation_height_factor == 1) { + // 1x1 unstrided undilated conv does not need an im2col array. 
return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 4e3ea721820cc6ff9638b3c8d487fd4800940122..1bc7557d46cfa5e1b27468d2da271e75fd491d36 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -139,6 +139,7 @@ DECLARE_GRAPH_TRANSFORMATION(PropagateFakeQuantNumBits); DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes) DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax) DECLARE_GRAPH_TRANSFORMATION(Quantize) +DECLARE_GRAPH_TRANSFORMATION(QuantizeWeights) DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp) DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert) DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity) @@ -182,6 +183,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveTransposeAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRandomUniform) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRange) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantSlice) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStack) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index 437e30a91803bfc847bf246875fa2924b7c0d3fe..bda6dce22be0f0ca83eb8339ad17573b0267c18c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -188,6 +188,32 @@ bool HardcodeMinMaxFromFirstInput(Model* model, Operator* op) { return true; } +bool HardcodeMinMaxForSelect(Model* model, Operator* op) { + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.minmax) { + return false; + } + const auto& 
input_array_1 = model->GetArray(op->inputs[1]); + if (!input_array_1.minmax) { + return false; + } + const auto& input_array_2 = model->GetArray(op->inputs[2]); + if (!input_array_2.minmax) { + return false; + } + + const auto& input_minmax_1 = input_array_1.GetMinMax(); + const auto& input_minmax_2 = input_array_2.GetMinMax(); + + CHECK_EQ(input_minmax_1.min, input_minmax_2.min); + CHECK_EQ(input_minmax_1.max, input_minmax_2.max); + CHECK(!output_array.minmax); + auto& output_minmax = output_array.GetOrCreateMinMax(); + output_minmax.min = input_minmax_1.min; + output_minmax.max = input_minmax_1.max; + return true; +} + bool HardcodeMinMaxForOutput(Model* model, Operator* op, double min, double max) { CHECK_EQ(op->outputs.size(), 1); @@ -336,6 +362,8 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { changed = HardcodeMinMaxForAverageOrMaxPool(model, op); break; + case OperatorType::kResizeBilinear: + case OperatorType::kSlice: case OperatorType::kStridedSlice: case OperatorType::kSqueeze: case OperatorType::kTensorFlowReshape: @@ -345,7 +373,9 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { case OperatorType::kMean: changed = HardcodeMinMaxFromFirstInput(model, op); break; - + case OperatorType::kSelect: + changed = HardcodeMinMaxForSelect(model, op); + break; case OperatorType::kLogistic: // We hardcode quantization_params to: zero_point=0, scale=1/256. // This choice of minmax is the one that is equivalent to that. 
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc index 3f768bfee12ebe31ebeb72855eb67ec03d5bcf8c..5b6a984ee143a6007471b165510030cd3ad3f73c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc @@ -33,9 +33,10 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) { return false; } - // Already a compact LstmCell with LstmCellOperator::NUM_INPUTS of inputs, - // do not need to merge cell inputs. - if (src_op->inputs.size() == LstmCellOperator::NUM_INPUTS) { + // Already a compact LstmCell. Do not need to merge cell inputs. + const auto* src_lstm_op = static_cast<LstmCellOperator*>(src_op); + if (src_lstm_op->kernel_type != LstmCellOperator::KERNEL_FULL || + src_lstm_op->inputs.size() != kExtendedLstmInputCount) { return false; } @@ -136,6 +137,7 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) { // Emplace a new LSTM cell operator (use basic 5 inputs kernel). auto lstm_cell_op = absl::make_unique<LstmCellOperator>(); + lstm_cell_op->kernel_type = LstmCellOperator::KERNEL_BASIC; // Compact LstmCell's 5 inputs. lstm_cell_op->inputs.resize(LstmCellOperator::NUM_INPUTS); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc index 8e66323bd769ca166d6b521c5b7b2f1cb944b0a2..e6e3dfa1de9c9fdd5e759fd547d11a7b8c95d837 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc @@ -33,9 +33,10 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { return false; } - // Already an extended LstmCell with kExtendedLstmInputCount of inputs, - // do not need to split cell inputs.
- if (curr_op->inputs.size() == kExtendedLstmInputCount) { + const auto* curr_lstm_op = static_cast<LstmCellOperator*>(curr_op); + // Already an extended LstmCell. Do not need to split cell inputs. + if (curr_lstm_op->kernel_type != LstmCellOperator::KERNEL_BASIC || + curr_lstm_op->inputs.size() != LstmCellOperator::NUM_INPUTS) { return false; } @@ -56,6 +57,7 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { // Emplace a new LstmCell operator with extended inputs (kernel/lstm.cc). auto lstm_cell_op = absl::make_unique<LstmCellOperator>(); + lstm_cell_op->kernel_type = LstmCellOperator::KERNEL_FULL; lstm_cell_op->inputs.resize(kExtendedLstmInputCount); int num_input = model->GetArray(curr_op->inputs[LstmCellOperator::DATA_INPUT]) .shape() diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 6342cf3e8af4d85ad869a5d60a63d62ca2b00588..92d283ca2cc7069f4b80c95ffdadcad81a884cbf 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -60,6 +60,8 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { case OperatorType::kTensorFlowLessEqual: case OperatorType::kTensorFlowGreater: case OperatorType::kTensorFlowGreaterEqual: + case OperatorType::kTensorFlowEqual: + case OperatorType::kTensorFlowNotEqual: // These operators unconditionally produce bool outputs SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool); break; @@ -163,6 +165,16 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { SetDataTypeForAllOutputs(model, op, data_type_x); break; } + case OperatorType::kSparseToDense: { + // SparseToDense produces outputs with the same type as their 3rd input + CHECK_EQ(op->inputs.size(), 4); + const ArrayDataType data_type = model->GetArray(op->inputs[2]).data_type; + const ArrayDataType
data_type_default = + model->GetArray(op->inputs[3]).data_type; + CHECK(data_type == data_type_default); + SetDataTypeForAllOutputs(model, op, data_type); + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc index 0bce183c1897dfba6f2c393ffc0306c054366725..6d51fc8c31e6c86701c3dc1fd07a9a5479114738 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc @@ -102,6 +102,7 @@ bool DoesOpBlockBackwardPropagation(const Operator& op) { // Gathers need their parameters changed to the appropriate data type. case OperatorType::kTensorFlowReshape: case OperatorType::kTranspose: + case OperatorType::kSelect: // Reshapes and transposes don't change values. return false; default: @@ -113,6 +114,8 @@ bool DoesOpBlockBackwardPropagation(const Operator& op) { // propagation. bool DoesOpInputBlockBackwardPropagation(const Operator& op, int input_index) { switch (op.type) { + case OperatorType::kSelect: + return input_index == 0; case OperatorType::kGather: // Ignore gather indices. return input_index != 0; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 52b739c5e27536a8f9903e0ba5c0422d469fb6db..170a499d4eeea6f38704dadd5274e52da7ae2817 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -278,7 +278,7 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { << "TransposeConv input shape must have 4 dimensions. 
Input \"" << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape " << toco::ShapeToString(weights_shape) << "."; - CHECK_EQ(input_shape.dims(3), weights_shape.dims(0)) + CHECK_EQ(input_shape.dims(3), weights_shape.dims(3)) << "Input shape depth and weight depth do not agree"; // Set the output shape according to the specified output shape. @@ -1477,6 +1477,34 @@ void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) { *output_array.mutable_shape()->mutable_dims() = output_dims; } +void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) { + CHECK_EQ(op->inputs.size(), 4); + + const Array& output_shape_array = model->GetArray(op->inputs[1]); + if (!output_shape_array.has_shape()) return; + CHECK_EQ(output_shape_array.shape().dimensions_count(), 1); + + // Output should not go over four dimensions. + CHECK_LE(output_shape_array.shape().dims(0), 4); + + const string& output_name = op->outputs[0]; + Array& output_array = model->GetArray(output_name); + if (output_array.has_shape()) return; + + CHECK(output_shape_array.data_type == ArrayDataType::kInt32 || + output_shape_array.data_type == ArrayDataType::kInt64); + if (output_shape_array.data_type == ArrayDataType::kInt32) { + *output_array.mutable_shape()->mutable_dims() = + output_shape_array.GetBuffer().data; + } else { + const std::vector& output_shape_data = + output_shape_array.GetBuffer().data; + std::copy( + output_shape_data.begin(), output_shape_data.end(), + std::back_inserter(*output_array.mutable_shape()->mutable_dims())); + } +} + } // namespace bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { @@ -1514,6 +1542,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kCast: case OperatorType::kFloor: case OperatorType::kExp: + case OperatorType::kSin: ProcessSimpleOperator(model, op, 0); break; case OperatorType::kGather: @@ -1534,6 +1563,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case 
OperatorType::kTensorFlowMaximum: case OperatorType::kTensorFlowMinimum: case OperatorType::kTensorFlowGreaterEqual: + case OperatorType::kTensorFlowEqual: + case OperatorType::kTensorFlowNotEqual: ProcessSimpleBinaryOperator(model, op); break; case OperatorType::kAddN: @@ -1699,6 +1730,10 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { CHECK_EQ(op->inputs.size(), 1); ProcessOpWithShapeInput(model, op); break; + case OperatorType::kSparseToDense: + ProcessSparseToDenseOperator(model, + static_cast(op)); + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index a1ca7371c87f4c95e754c5484ddfc4063c46c184..eca2c701f8bbf889088794c939af7082db0734dd 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -45,12 +45,14 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kTensorFlowMinimum || type == OperatorType::kTensorFlowMaximum || type == OperatorType::kLogistic || type == OperatorType::kSoftmax || - type == OperatorType::kLogSoftmax || + type == OperatorType::kLogSoftmax || type == OperatorType::kSlice || + type == OperatorType::kResizeBilinear || type == OperatorType::kTensorFlowSplit || type == OperatorType::kSub || type == OperatorType::kSqueeze || type == OperatorType::kPad || type == OperatorType::kPadV2 || type == OperatorType::kTensorFlowReshape || type == OperatorType::kTanh || type == OperatorType::kMul || + type == OperatorType::kSpaceToBatchND || type == OperatorType::kSpaceToDepth || type == OperatorType::kStridedSlice || type == OperatorType::kDepthToSpace || @@ -59,7 +61,8 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kTensorFlowGreater || type == 
OperatorType::kTensorFlowGreaterEqual || type == OperatorType::kTensorFlowLess || - type == OperatorType::kTensorFlowLessEqual; + type == OperatorType::kTensorFlowLessEqual || + type == OperatorType::kSelect || type == OperatorType::kArgMax; } const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc new file mode 100644 index 0000000000000000000000000000000000000000..88ea0945e7dd15ba325d34ea3fdbf34ff7d91381 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +namespace { + +// The minimum number of elements a weights array must have to be quantized +// by this transformation. +// TODO(suharshs): Make this minimum size configurable. +const int kWeightsMinSize = 1024; + +// Gets the quantization params from the float array. 
+void GetQuantizationParamsFromArray(const Array& array, + QuantizationParams* params) { + const std::vector<float>& float_vals = + array.GetBuffer<ArrayDataType::kFloat>().data; + auto minmax = std::minmax_element(float_vals.begin(), float_vals.end()); + MinMax toco_minmax; + toco_minmax.min = *minmax.first; + toco_minmax.max = *minmax.second; + GetQuantizationParams(ArrayDataType::kUint8, toco_minmax, params); +} + +} // namespace + +bool QuantizeWeights::Run(Model* model, std::size_t op_index) { + const auto op_it = model->operators.begin() + op_index; + Operator* op = op_it->get(); + + // Get the weights tensor, if the current operator has one. + int weights_index; + if (op->type == OperatorType::kConv || + op->type == OperatorType::kDepthwiseConv || + op->type == OperatorType::kFullyConnected) { + weights_index = 1; + } else if (op->type == OperatorType::kLstmCell) { + weights_index = LstmCellOperator::WEIGHTS_INPUT; + } else { + return false; + } + + // Return early if the array isn't a constant param, this can happen in early + // transformation passes until transpose operations following the weight array + // are resolved. + const string weights = op->inputs[weights_index]; + if (!IsConstantParameterArray(*model, weights)) { + return false; + } + + // Return early if the weight tensor is not type float. + Array& weights_array = model->GetArray(weights); + if (weights_array.data_type != ArrayDataType::kFloat) { + return false; + } + + // Return early if the tensor is too small. Small tensors don't take up too + // much space and can result in bad quantization results. + if (weights_array.GetBuffer<ArrayDataType::kFloat>().data.size() < + kWeightsMinSize) { + return false; + } + + // Quantize the weight tensor to type kUint8. + QuantizationParams params; + GetQuantizationParamsFromArray(weights_array, &params); + QuantizeArray(this, model, weights, ArrayDataType::kUint8, params); + + // Insert a Dequantize operation after the quantized weights tensor.
+ auto* dequantize_op = new DequantizeOperator; + model->operators.emplace(op_it, dequantize_op); + + // Create a new intermediate tensor to connect the Dequantize op to the + // original op. + const string dequantized_output = + AvailableArrayName(*model, weights + "_dequantized"); + Array& dequantized_output_array = model->GetOrCreateArray(dequantized_output); + dequantized_output_array.data_type = ArrayDataType::kFloat; + + // Connect up the new Dequantize op with the weights and original op. + op->inputs[weights_index] = dequantized_output; + dequantize_op->inputs = {weights}; + dequantize_op->outputs = {dequantized_output}; + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc index 3e021b819fc82d66fb70596a62fd7cee4911d4e8..a950fe6442bc656b725a1f0687f4c024f4fb0f84 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc @@ -85,9 +85,11 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, "Removing %s, keeping its non-constant input array %s and removing %s", LogName(*passthru_op), main_input_name, output_name); RerouteEdges(output_name, main_input_name, model); - } else if (IsDiscardableArray(*model, main_input_name)) { + } else if (IsDiscardableArray(*model, main_input_name) && + !IsConstantParameterArray(*model, main_input_name)) { transformation->AddMessageF( - "Removing %s, keeping its output array %s and removing input %s", + "Removing %s, keeping its output array %s and removing non-constant " + "input %s", LogName(*passthru_op), output_name, main_input_name); RerouteEdges(main_input_name, output_name, model); } else { @@ -95,10 +97,23 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, "Cannot remove %s, neither its main input nor 
its output may be " "discarded", LogName(*passthru_op)); - return false; + if (passthru_op->type != OperatorType::kTensorFlowReshape && + model->GetArray(main_input_name).has_shape()) { + // We can't remove either array but we can remove the op. Converting it to + // a reshape gives us some hope of later on fixing that (either in the + // final runtime or as an additional fixup step). + // + // Note that we don't try to insert copies in place of reshapes as the + // copy itself is a trivial reshape and we'd go into an infinite loop! + transformation->AddMessageF("Replacing with a copy (reshape) instead"); + InsertCopyOperator(model, main_input_name, output_name); + } else { + return false; + } } // Remove the pass-through node. + CHECK_EQ(passthru_it->get(), passthru_op); model->operators.erase(passthru_it); // Remove any array that is no longer used. diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc new file mode 100644 index 0000000000000000000000000000000000000000..b35c3e19c43b1c62e6bdbfe379631480e1d41703 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc @@ -0,0 +1,165 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +template <ArrayDataType Type> +bool Slice(SliceOperator const& op, Array const& input_array, + Array* output_array) { + // Implementation is taken from the tflite kernel. + + CHECK(input_array.data_type == Type); + CHECK(output_array->data_type == Type); + const auto& input_data = input_array.GetBuffer<Type>().data; + + // Create a buffer for the output array. + std::vector<DataType<Type>>& output_data = + output_array->GetMutableBuffer<Type>().data; + output_data.resize(RequiredBufferSizeForShape(output_array->shape())); + + std::vector<int> size = op.size; + if (size.size() != op.begin.size()) { + // Broadcast the end positions. + CHECK_EQ(op.size.size(), 1); + int broadcast_size = size[0]; + while (size.size() < op.begin.size()) size.push_back(broadcast_size); + } + + // Calculate begin and end indices along each dimension. + CHECK_LE(op.begin.size(), 4); + CHECK_LE(size.size(), 4); + std::vector<int> begin = op.begin; + std::vector<int> end; + for (int i = 0; i < begin.size(); ++i) { + int dim_size = size[i]; + if (dim_size == -1) { + // -1 means the rest of the dimension. + dim_size = input_array.shape().dims()[i] - begin[i]; + } + CHECK_GE(dim_size, 1); + end.push_back(begin[i] + dim_size - 1); + } + + // Pad out so that we always have 4 dims, makes this loop easier.
+ while (begin.size() < 4) begin.insert(begin.begin(), 0); + while (end.size() < 4) end.insert(end.begin(), 0); + Shape padded_shape = input_array.shape(); + while (padded_shape.dimensions_count() < 4) { + padded_shape.mutable_dims()->insert(padded_shape.mutable_dims()->begin(), + 1); + } + + auto* out_ptr = output_data.data(); + for (int in_b = begin[0]; in_b <= end[0]; ++in_b) { + for (int in_h = begin[1]; in_h <= end[1]; ++in_h) { + for (int in_w = begin[2]; in_w <= end[2]; ++in_w) { + for (int in_d = begin[3]; in_d <= end[3]; ++in_d) { + *out_ptr++ = + input_data[Offset(padded_shape, {in_b, in_h, in_w, in_d})]; + } + } + } + } + + return true; +} + +} // namespace + +bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) { + const auto it = model->operators.begin() + op_index; + const auto* base_op = it->get(); + if (base_op->type != OperatorType::kSlice) { + return false; + } + + const SliceOperator* op = static_cast<const SliceOperator*>(base_op); + + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes. + return false; + } + + if (!output_array.has_shape()) { + // Yield until the output shape has been set by PropagateFixedShapes. + return false; + } + + if (op->begin.empty() || op->size.empty()) { + // Attributes have not resolved yet. + return false; + } + + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.has_shape()) { + // Yield until the value shape has been resolved. + return false; + } + if (!IsConstantParameterArray(*model, op->inputs[0])) { + // Yield until the value is constant.
+ return false; + } + + CHECK(!output_array.buffer); + switch (output_array.data_type) { + case ArrayDataType::kFloat: + if (!Slice<ArrayDataType::kFloat>(*op, input_array, &output_array)) { + return false; + } + break; + case ArrayDataType::kUint8: + if (!Slice<ArrayDataType::kUint8>(*op, input_array, &output_array)) { + return false; + } + break; + case ArrayDataType::kInt32: + if (!Slice<ArrayDataType::kInt32>(*op, input_array, &output_array)) { + return false; + } + break; + case ArrayDataType::kInt64: + if (!Slice<ArrayDataType::kInt64>(*op, input_array, &output_array)) { + return false; + } + break; + default: + LOG(FATAL) << "Unsupported data type input to Slice op with output \"" + << op->outputs[0] << "\""; + break; + } + + // Erase input array if no longer used. + if (IsDiscardableArray(*model, op->inputs[0]) && + CountOpsWithInput(*model, op->inputs[0]) == 1) { + model->EraseArray(op->inputs[0]); + } + + // Erase the operator + model->operators.erase(it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc index 021e9918f2cf22d3854491762c31061832359a46..65132d7d1ef0626e0ad41a88b8e7999a1c1cf684 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc @@ -19,6 +19,24 @@ limitations under the License.
namespace toco { +int PadAttributeArray(Array* attribute_array, std::vector<int> pad_values, + int mask) { + int attribute_dim_count = attribute_array->shape().dims(0); + int dim_count = pad_values.size(); + if (attribute_dim_count < dim_count) { + Shape strided_slice_shape = Shape({dim_count}); + attribute_array->copy_shape(strided_slice_shape); + Buffer<ArrayDataType::kInt32>* buffer = + &(attribute_array->GetMutableBuffer<ArrayDataType::kInt32>()); + buffer->data.resize(RequiredBufferSizeForShape(strided_slice_shape)); + for (int i = attribute_dim_count; i < dim_count; i++) { + buffer->data[i] = pad_values[i]; + mask |= 1 << i; + } + } + return mask; +} + bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) { const auto slice_it = model->operators.begin() + op_index; auto* slice_op = slice_it->get(); @@ -37,52 +55,63 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) { return false; } - const auto& start_array = model->GetArray(op->inputs[1]); + auto& start_array = model->GetArray(op->inputs[1]); if (!start_array.has_shape()) return false; if (toco::RequiredBufferSizeForShape(start_array.shape()) > 4) { // Only 1-4D arrays are supported for now.
return false; } - const auto& stop_array = model->GetArray(op->inputs[2]); + auto& stop_array = model->GetArray(op->inputs[2]); if (!stop_array.has_shape()) return false; - const auto& stride_array = model->GetArray(op->inputs[3]); + auto& stride_array = model->GetArray(op->inputs[3]); if (!stride_array.has_shape()) return false; if (!IsConstantParameterArray(*model, op->inputs[1])) return false; if (!IsConstantParameterArray(*model, op->inputs[2])) return false; if (!IsConstantParameterArray(*model, op->inputs[3])) return false; - op->start_indices = start_array.GetBuffer().data; - op->stop_indices = stop_array.GetBuffer().data; - op->strides = stride_array.GetBuffer().data; + int num_input_axes = input_array.shape().dimensions_count(); + int start_indices_size = start_array.shape().dims(0); + int stop_indices_size = stop_array.shape().dims(0); + int stride_indices_size = stride_array.shape().dims(0); - CHECK_GE(op->start_indices.size(), 1); - CHECK_LE(op->start_indices.size(), 4); - CHECK_EQ(op->stop_indices.size(), op->start_indices.size()); - CHECK_EQ(op->strides.size(), op->stop_indices.size()); + CHECK_GE(start_indices_size, 1); + CHECK_LE(start_indices_size, 4); + CHECK_LE(stop_indices_size, 4); + CHECK_LE(stride_indices_size, 4); // The TensorFlow documentation is not explicit on how it handles fewer // supplied indices than dimensions, but they are accepted. We emulate TF's // behavior by fully iterating over each omitted dimension. 
- int num_input_axes = input_array.shape().dimensions_count(); - CHECK_LE(op->start_indices.size(), num_input_axes) + CHECK_LE(start_indices_size, num_input_axes) << "StridedSlice op requires no more than " << num_input_axes << " start indices"; - CHECK_LE(op->stop_indices.size(), num_input_axes) + CHECK_LE(stop_indices_size, num_input_axes) << "StridedSlice op requires no more than " << num_input_axes << " stop indices"; - CHECK_LE(op->strides.size(), num_input_axes) + CHECK_LE(stride_indices_size, num_input_axes) << "StridedSlice op requires no more than " << num_input_axes << " strides"; - op->PadIndices(num_input_axes); // Ideally, we would remove the input arrays after they have been resolved. // However, we must then reconstitute these input arrays for all supported // export formats. For now, leave the arrays so we don't have to modify our // exporters. Ideally, we wouldn't have op attributes, and would work directly // with the input arrays. + std::vector<int> begin_pad_values(num_input_axes, 0); + op->begin_mask = + PadAttributeArray(&start_array, begin_pad_values, op->begin_mask); + op->end_mask = + PadAttributeArray(&stop_array, input_array.shape().dims(), op->end_mask); + std::vector<int> stride_pad_values(num_input_axes, 1); + PadAttributeArray(&stride_array, stride_pad_values, 0); + + op->start_indices = start_array.GetBuffer<ArrayDataType::kInt32>().data; + op->stop_indices = stop_array.GetBuffer<ArrayDataType::kInt32>().data; + op->strides = stride_array.GetBuffer<ArrayDataType::kInt32>().data; + return true; } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD index 8dcd4adc90b188c745cadb9815c3c46383705833..95e8433be2a332cfce5175f4f65ea0b83d5638c5 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD @@ -8,8 +8,8 @@ load( ) tf_cc_test( - name = "resolve_constant_concatenation_test", - srcs =
["resolve_constant_concatenation_test.cc"], + name = "lstm_utils_test", + srcs = ["lstm_utils_test.cc"], deps = [ "//tensorflow/contrib/lite/toco:graph_transformations", "//tensorflow/contrib/lite/toco:model", @@ -19,8 +19,20 @@ tf_cc_test( ) tf_cc_test( - name = "lstm_utils_test", - srcs = ["lstm_utils_test.cc"], + name = "quantize_weights_test", + srcs = ["quantize_weights_test.cc"], + deps = [ + "//tensorflow/contrib/lite/toco:graph_transformations", + "//tensorflow/contrib/lite/toco:model", + "//tensorflow/contrib/lite/toco:tooling_util", + "@com_google_absl//absl/memory", + "@com_google_googletest//:gtest_main", + ], +) + +tf_cc_test( + name = "resolve_constant_concatenation_test", + srcs = ["resolve_constant_concatenation_test.cc"], deps = [ "//tensorflow/contrib/lite/toco:graph_transformations", "//tensorflow/contrib/lite/toco:model", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c05eb0929fd775d315fa735b4c9842a7fc024fa8 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc @@ -0,0 +1,167 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include + +#include +#include +#include "absl/memory/memory.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +class QuantizeWeightsTest : public ::testing::Test { + protected: + QuantizeWeightsTest() {} + + // The name of the weights input array. + const string kWeightsName = "weights"; + // The zero_point of the values in the input array. + const int kZeroPoint = 128; + + // Prepare a hypothetical TOCO model of a quantizable fully connected float + // layer. + void PrepareModel(Model* model, int elements_per_dim) { + std::vector fc_input_names = {"inputs", kWeightsName}; + + const int kDim = 4; + const int buf_size = std::pow(elements_per_dim, static_cast(kDim)); + auto in_buf = absl::make_unique(buf_size); + // Initialize the array with values from -128.0 to 127.0, since these values + // should be exactly representable by quantization. + for (int i = 0; i < buf_size; i++) { + in_buf[i] = static_cast(i % 256 - kZeroPoint); + } + + for (const string& fc_input_name : fc_input_names) { + Array& in_array = model->GetOrCreateArray(fc_input_name); + in_array.data_type = ArrayDataType::kFloat; + + // Initialize shape for the input array. 
+ Shape* in_array_shape = in_array.mutable_shape(); + std::vector* in_array_shape_dim = in_array_shape->mutable_dims(); + in_array_shape_dim->resize(kDim, elements_per_dim); + auto& in_array_buffer = + in_array.GetMutableBuffer(); + in_array_buffer.data.resize(buf_size); + float* buf_ptr = + in_array.GetMutableBuffer().data.data(); + std::copy(in_buf.get(), in_buf.get() + buf_size, buf_ptr); + } + + auto* fc_op = new FullyConnectedOperator; + fc_op->inputs = fc_input_names; + fc_op->outputs = {"fc_op_outputs"}; + Array& out_array = model->GetOrCreateArray(fc_op->outputs[0]); + out_array.data_type = ArrayDataType::kFloat; + Shape* out_array_shape = out_array.mutable_shape(); + std::vector* out_array_shape_dim = out_array_shape->mutable_dims(); + out_array_shape_dim->resize(kDim, elements_per_dim); + model->operators.push_back(std::unique_ptr(fc_op)); + } +}; + +TEST_F(QuantizeWeightsTest, QuantizedFullyConnected) { + // Test that weight arrays that are large enough are quantized. + Model model; + // 6 elements per dim gives us 1296 elements, which is sufficient to be + // quantized. + PrepareModel(&model, 6); + + // Check the state of the graph before the transformation. + const auto& float_array_map = model.GetArrayMap(); + EXPECT_EQ(float_array_map.size(), 3); + // Before the transformation, all arrays should be type float. + for (const auto& element : float_array_map) { + EXPECT_EQ(element.second->data_type, ArrayDataType::kFloat); + } + const std::vector float_weight_vals = + model.GetArray(kWeightsName).GetBuffer().data; + + // Invoke the transformation. + GraphTransformationsSet graph_transformation_set; + graph_transformation_set.Add(new toco::QuantizeWeights); + (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + + // Check the state of the graph after the transformation. 
+ const auto& quantized_array_map = model.GetArrayMap(); + EXPECT_EQ(quantized_array_map.size(), 4); + // After the transformation, three arrays should be type float and one array + // should be uint8. + int num_float = 0; + int num_uint8 = 0; + for (const auto& element : quantized_array_map) { + if (element.second->data_type == ArrayDataType::kFloat) { + num_float++; + } else if (element.second->data_type == ArrayDataType::kUint8) { + num_uint8++; + } else { + FAIL() << "Unexpected array type."; + } + } + EXPECT_EQ(num_float, 3); + EXPECT_EQ(num_uint8, 1); + // Ensure that the values were quantized correctly. + const std::vector& quantized_weight_vals = + model.GetArray(kWeightsName).GetBuffer().data; + for (int i = 0; i < quantized_weight_vals.size(); i++) { + EXPECT_EQ(quantized_weight_vals[i], float_weight_vals[i] + kZeroPoint); + } + + // Ensure that a Dequantize operator has been inserted before the + // FullyConnectedLayer. + EXPECT_EQ(model.operators[0]->type, OperatorType::kDequantize); +} + +TEST_F(QuantizeWeightsTest, NotQuantizedFullyConnected) { + // Test that weight arrays that are too small are left untouched. + Model model; + // 5 elements per dim gives us 625 elements, which is NOT sufficient to be + // quantized. + PrepareModel(&model, 5); + + // Check the state of the graph before the transformation. + const auto& float_array_map = model.GetArrayMap(); + EXPECT_EQ(float_array_map.size(), 3); + // Before the transformation, all arrays should be type float. + for (auto it = float_array_map.begin(); it != float_array_map.end(); it++) { + EXPECT_EQ(it->second->data_type, ArrayDataType::kFloat); + } + std::vector float_weight_vals = + model.GetArray(kWeightsName).GetBuffer().data; + + // Invoke the transformation. + GraphTransformationsSet graph_transformation_set; + graph_transformation_set.Add(new toco::QuantizeWeights); + (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + + // Check the state of the graph after the transformation. 
+ const auto& post_array_map = model.GetArrayMap(); + EXPECT_EQ(post_array_map.size(), 3); + for (auto it = post_array_map.begin(); it != post_array_map.end(); it++) { + EXPECT_EQ(it->second->data_type, ArrayDataType::kFloat); + } + // Ensure that the values remain unchanged. + std::vector const& quantized_weight_vals = + model.GetArray(kWeightsName).GetBuffer().data; + for (int i = 0; i < quantized_weight_vals.size(); i++) { + EXPECT_EQ(quantized_weight_vals[i], float_weight_vals[i]); + } +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc index 3a1d175b9823f085c9b8730caba8bedd7eb87d52..66cfed4ac26969729d1881f11ba6ae74d9817fb5 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc @@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include -#include #include #include @@ -126,7 +124,7 @@ class ResolveConstantConcatenationTest : public ::testing::Test { Array& in_array = model->GetOrCreateArray(concat_input_name); in_array.data_type = ArrayDataType::kFloat; - // Initialize shape for the input array. + // Initialize shape for the input array. 
Shape* in_array_shape = in_array.mutable_shape(); std::vector* in_array_shape_dim = in_array_shape->mutable_dims(); for (int i = 0; i < kDim; i++) { diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 1eef173afe596e1385386f8fba6c7c83e81d94d2..c1c2997c6b2a98085e2a9f4e910ac8f6099ab5ef 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -48,6 +48,12 @@ limitations under the License. #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" +#define TOCO_RETURN_IF_ERROR(...) \ + do { \ + const ::toco::port::Status _status = (__VA_ARGS__); \ + if (!_status.ok()) return _status; \ + } while (0) + using tensorflow::AttrValue; using tensorflow::DT_BOOL; using tensorflow::DT_FLOAT; @@ -130,6 +136,37 @@ const AttrValue::ListValue& GetListAttr(const NodeDef& node, return attr.list(); } +Status CheckOptionalAttr(const NodeDef& node, const string& attr_name, + const string& expected_value) { + if (HasAttr(node, attr_name)) { + const string& value = GetStringAttr(node, attr_name); + if (value != expected_value) { + return Status(false, "Unexpected value for attribute '" + attr_name + + "'. Expected '" + expected_value + "'"); + } + } + return Status::OK(); +} +Status CheckOptionalAttr(const NodeDef& node, const string& attr_name, + const tensorflow::DataType& expected_value) { + if (HasAttr(node, attr_name)) { + const tensorflow::DataType& value = GetDataTypeAttr(node, attr_name); + if (value != expected_value) { + return Status(false, "Unexpected value for attribute '" + attr_name + + "'. 
Expected '" + + tensorflow::DataType_Name(expected_value) + "'"); + } + } + return Status::OK(); +} + +template <typename T1, typename T2> +Status ExpectValue(const T1& v1, const T2& v2, const string& description) { + if (v1 == v2) return Status::OK(); + return Status(false, absl::StrCat("Unexpected ", description, ": got ", v1, + ", expected ", v2)); +} + ArrayDataType ConvertDataType(tensorflow::DataType dtype) { if (dtype == DT_UINT8) return ArrayDataType::kUint8; @@ -189,6 +226,7 @@ Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { output_array->GetMutableBuffer<ArrayDataType::kFloat>().data; output_float_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0.f); + CHECK_GE(output_float_data.size(), input_flat_size); if (input_tensor.float_val_size() == 1) { for (int i = 0; i < input_flat_size; i++) { output_float_data[i] = input_tensor.float_val(0); @@ -202,9 +240,13 @@ Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast<char*>(output_float_data.data())); } else { - return Status(false, - "Neither input_content nor float_val have the right " - "dimensions for this float tensor"); + return Status( + false, + absl::StrCat("Neither input_content (", + input_tensor.tensor_content().size() / sizeof(float), + ") nor float_val (", input_tensor.float_val_size(), + ") have the right dimensions (", input_flat_size, + ") for this float tensor")); } return Status::OK(); } @@ -221,6 +263,7 @@ Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { auto& output_int_data = output_array->GetMutableBuffer<ArrayDataType::kUint8>().data; output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); + CHECK_GE(output_int_data.size(), input_flat_size); if (input_tensor.int_val_size()) { for (int i = 0; i < input_tensor.int_val_size(); i++) { output_int_data[i] = input_tensor.int_val(i); @@ -230,9 +273,13 @@ Status ImportQuint8Array(const TensorProto& input_tensor, Array*
output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast(output_int_data.data())); } else { - return Status(false, - "Neither input_content nor int_val have the right dimensions " - "for this uint8 tensor"); + return Status( + false, + absl::StrCat("Neither input_content (", + input_tensor.tensor_content().size() / sizeof(uint8_t), + ") nor int_val (", input_tensor.int_val_size(), + ") have the right dimensions (", input_flat_size, + ") for this uint8 tensor")); } return Status::OK(); } @@ -249,6 +296,7 @@ Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { auto& output_int_data = output_array->GetMutableBuffer().data; output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); + CHECK_GE(output_int_data.size(), input_flat_size); if (input_tensor.int_val_size()) { for (int i = 0; i < input_tensor.int_val_size(); i++) { output_int_data[i] = input_tensor.int_val(i); @@ -258,9 +306,13 @@ Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast(output_int_data.data())); } else { - return Status(false, - "Neither input_content nor int_val have the right dimensions " - "for this int32 tensor"); + return Status( + false, + absl::StrCat("Neither input_content (", + input_tensor.tensor_content().size() / sizeof(int32), + ") nor int_val (", input_tensor.int_val_size(), + ") have the right dimensions (", input_flat_size, + ") for this int32 tensor")); } return Status::OK(); } @@ -277,6 +329,7 @@ Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { auto& output_int_data = output_array->GetMutableBuffer().data; output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); + CHECK_GE(output_int_data.size(), input_flat_size); if (input_tensor.int64_val_size()) { for (int i = 0; i < input_tensor.int64_val_size(); i++) { output_int_data[i] = input_tensor.int64_val(i); 
@@ -286,9 +339,13 @@ Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast(output_int_data.data())); } else { - return Status(false, - "Neither input_content nor int64_val have the right " - "dimensions for this int64 tensor"); + return Status( + false, + absl::StrCat("Neither input_content (", + input_tensor.tensor_content().size() / sizeof(int64), + ") nor int64_val (", input_tensor.int64_val_size(), + ") have the right dimensions (", input_flat_size, + ") for this int64 tensor")); } return Status::OK(); } @@ -306,6 +363,7 @@ Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) { output_array->GetMutableBuffer().data; output_bool_data.resize(RequiredBufferSizeForShape(output_array->shape()), false); + CHECK_GE(output_bool_data.size(), input_flat_size); if (input_tensor.bool_val_size()) { for (int i = 0; i < input_tensor.bool_val_size(); i++) { output_bool_data[i] = input_tensor.bool_val(i); @@ -322,9 +380,12 @@ Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) { // So far only encountered that in an array with 1 entry, let's // require that until we encounter a graph where that's not the case. 
if (output_bool_data.size() != 1) { - return Status(false, - "Neither input_content nor bool_val have the right " - "dimensions for this bool tensor"); + return Status( + false, absl::StrCat("Neither input_content (", + input_tensor.tensor_content().size(), + ") nor bool_val (", input_tensor.bool_val_size(), + ") have the right dimensions (", input_flat_size, + ") for this bool tensor")); } output_bool_data[0] = false; } @@ -340,13 +401,16 @@ Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) { output_array->mutable_shape()); if (!status.ok()) return status; + if (input_flat_size != input_tensor.string_val_size()) { + return Status(false, + "Input_content string_val doesn't have the right dimensions " + "for this string tensor"); + } + auto& output_string_data = output_array->GetMutableBuffer().data; output_string_data.resize(RequiredBufferSizeForShape(output_array->shape())); - if (input_flat_size != input_tensor.string_val_size()) { - LOG(FATAL) << "Input_content string_val doesn't have the right " - "dimensions for this string tensor."; - } + CHECK_GE(output_string_data.size(), input_flat_size); for (int i = 0; i < input_flat_size; ++i) { output_string_data[i] = input_tensor.string_val(i); } @@ -439,18 +503,16 @@ Status ConvertConstOperator(const NodeDef& node, return status; } -void ConvertConvOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +Status ConvertConvOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Conv2D"); CheckInputsCount(node, tf_import_flags, 2); // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. 
- if (HasAttr(node, "data_format")) { - CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC"); - } - CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + TOCO_RETURN_IF_ERROR(CheckOptionalAttr(node, "data_format", "NHWC")); + TOCO_RETURN_IF_ERROR(CheckOptionalAttr(node, "T", DT_FLOAT)); const auto& input_name = node.input(0); const auto& weights_name = node.input(1); @@ -475,27 +537,27 @@ void ConvertConvOperator(const NodeDef& node, auto* conv = new ConvOperator; conv->inputs = {input_name, reordered_weights_name}; conv->outputs = {node.name()}; + TOCO_RETURN_IF_ERROR( + Status(HasAttr(node, "strides"), "Missing attribute 'strides'")); const auto& strides = GetListAttr(node, "strides"); - CHECK_EQ(strides.i_size(), 4); - CHECK_EQ(strides.i(0), 1); - CHECK_EQ(strides.i(3), 1); + TOCO_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides")); + TOCO_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)")); + TOCO_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)")); conv->stride_height = strides.i(1); conv->stride_width = strides.i(2); if (HasAttr(node, "dilations")) { const auto& dilations = GetListAttr(node, "dilations"); - CHECK_EQ(dilations.i_size(), 4); - CHECK_EQ(dilations.i(0), 1) - << "Can only import Conv ops with dilation along the height (1st) or " - "width (2nd) axis. TensorFlow op \"" - << node.name() << "\" had dilations:[ " << dilations.i(0) << ", " - << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3) - << "]."; - CHECK_EQ(dilations.i(3), 1) - << "Can only import Conv ops with dilation along the height (1st) or " - "width (2nd) axis. 
TensorFlow op \"" - << node.name() << "\" had dilations:[ " << dilations.i(0) << ", " - << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3) - << "]."; + TOCO_RETURN_IF_ERROR( + ExpectValue(dilations.i_size(), 4, "number of dilations")); + if (dilations.i(0) != 1 || dilations.i(3) != 1) { + return Status( + false, absl::StrCat( + "Can only import Conv ops with dilation along the height " + "(1st) or width (2nd) axis. TensorFlow op \"", + node.name(), "\" had dilations:[ ", dilations.i(0), ", ", + dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), + "].")); + } conv->dilation_height_factor = dilations.i(1); conv->dilation_width_factor = dilations.i(2); } else { @@ -508,9 +570,11 @@ void ConvertConvOperator(const NodeDef& node, } else if (padding == "VALID") { conv->padding.type = PaddingType::kValid; } else { - LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + return Status(false, "Bad padding (only SAME and VALID are supported)"); } model->operators.emplace_back(conv); + + return Status::OK(); } void ConvertDepthwiseConvOperator(const NodeDef& node, @@ -587,7 +651,14 @@ void ConvertSpaceToDepthOperator(const NodeDef& node, CHECK_EQ(node.op(), "SpaceToDepth"); CheckInputsCount(node, tf_import_flags, 1); - CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + tensorflow::DataType dtype = GetDataTypeAttr(node, "T"); + if (dtype != DT_FLOAT && dtype != DT_UINT8 && dtype != DT_INT32 && + dtype != DT_INT64) { + const auto* enum_descriptor = tensorflow::DataType_descriptor(); + LOG(FATAL) << "TFLite does not support SpaceToDepth with type T:" + << enum_descriptor->FindValueByNumber(dtype)->name() << ". 
" + << "T must be one of {DT_FLOAT, DT_INT8, DT_INT32, DT_INT64}."; + } auto* op = new SpaceToDepthOperator; op->inputs.push_back(node.input(0)); op->outputs.push_back(node.name()); @@ -629,81 +700,6 @@ void ConvertRandomUniform(const NodeDef& node, model->operators.emplace_back(std::move(op)); } -void ConvertReluOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Relu"); - CheckInputsCount(node, tf_import_flags, 1); - const auto& input_name = node.input(0); - auto* relu = new ReluOperator; - relu->inputs.push_back(input_name); - relu->outputs.push_back(node.name()); - model->operators.emplace_back(relu); -} - -void ConvertRelu6Operator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Relu6"); - CheckInputsCount(node, tf_import_flags, 1); - - const auto& input_name = node.input(0); - auto* op = new Relu6Operator; - op->inputs.push_back(input_name); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertLogOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Log"); - CheckInputsCount(node, tf_import_flags, 1); - - auto op = absl::make_unique(); - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(std::move(op)); -} - -void ConvertLogisticOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sigmoid"); - CheckInputsCount(node, tf_import_flags, 1); - - const auto& input_name = node.input(0); - auto* op = new LogisticOperator; - op->inputs.push_back(input_name); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertTanhOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Tanh"); - CheckInputsCount(node, tf_import_flags, 
1); - - const auto& input_name = node.input(0); - auto* op = new TanhOperator; - op->inputs.push_back(input_name); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertDivOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK(node.op() == "Div" || node.op() == "RealDiv"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new DivOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertIdentityOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -760,38 +756,6 @@ void ConvertFakeQuantWithMinMaxVars( model->operators.emplace_back(op); } -void ConvertNegOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Neg"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new NegOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertRsqrtOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Rsqrt"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new TensorFlowRsqrtOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSqrtOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sqrt"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new TensorFlowSqrtOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} void ConvertSqueezeOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, @@ -813,66 +777,6 @@ void ConvertSqueezeOperator(const NodeDef& 
node, model->operators.emplace_back(op); } -void ConvertSquareOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Square"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new TensorFlowSquareOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertAddOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Add"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new AddOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertAddNOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "AddN"); - const int num_inputs = GetInputsCount(node, tf_import_flags); - auto* op = new AddNOperator; - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertMulOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Mul"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new MulOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSubOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sub"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new SubOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertSumOperator(const NodeDef& node, const 
TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -888,67 +792,6 @@ void ConvertSumOperator(const NodeDef& node, } } -void ConvertTileOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Tile"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowTileOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSliceOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Slice"); - CheckInputsCount(node, tf_import_flags, 3); - auto* op = new SliceOperator; - for (int i = 0; i < 3; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertPadOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Pad"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new PadOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertPadV2Operator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "PadV2"); - CheckInputsCount(node, tf_import_flags, 3); - auto* op = new PadV2Operator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->inputs.push_back(node.input(2)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertShapeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Shape"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new TensorFlowShapeOperator; - op->inputs.push_back(node.input(0)); - 
op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertSplitOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -966,18 +809,6 @@ void ConvertSplitOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertMergeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Merge"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowMergeOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertSwitchOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1007,18 +838,6 @@ void ConvertSoftmaxOperator(const NodeDef& node, model->operators.emplace_back(softmax); } -void ConvertLogSoftmaxOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "LogSoftmax"); - CheckInputsCount(node, tf_import_flags, 1); - const auto& input_name = node.input(0); - auto* log_softmax = new LogSoftmaxOperator; - log_softmax->inputs.push_back(input_name); - log_softmax->outputs.push_back(node.name()); - model->operators.emplace_back(log_softmax); -} - void ConvertLRNOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1115,17 +934,6 @@ void ConvertAvgPoolOperator(const NodeDef& node, model->operators.emplace_back(avgpool); } -void ConvertReshapeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Reshape"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowReshapeOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} void ConvertBatchMatMulOperator(const 
NodeDef& node, const TensorFlowImportFlags& tf_import_flags, @@ -1188,37 +996,12 @@ void ConvertConcatOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertAllOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "All"); - auto* op = new TensorFlowAllOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertAssertOperator(const NodeDef& node, +// This method supports simple operators without additional attributes. +template <typename Op> +void ConvertSimpleOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { - CHECK_EQ(node.op(), "Assert"); - auto* op = new TensorFlowAssertOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertLessOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Less"); - auto* op = new TensorFlowLessOperator; + auto* op = new Op; const int num_inputs = GetInputsCount(node, tf_import_flags); for (int i = 0; i < num_inputs; ++i) { op->inputs.push_back(node.input(i)); @@ -1227,43 +1010,13 @@ void ConvertLessOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertLessEqualOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "LessEqual"); - auto* op = new TensorFlowLessEqualOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); 
-} - -void ConvertGreaterOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Greater"); - auto* op = new TensorFlowGreaterOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertGreaterEqualOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "GreaterEqual"); - auto* op = new TensorFlowGreaterEqualOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); +// This method supports simple operators without additional attributes. +template <typename Op, unsigned int NumInputs> +void ConvertSimpleOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CheckInputsCount(node, tf_import_flags, NumInputs); + ConvertSimpleOperator<Op>(node, tf_import_flags, model); } void ConvertMaxOperator(const NodeDef& node, @@ -1296,29 +1049,6 @@ void ConvertMinOperator(const NodeDef& node, } } -void ConvertMaximumOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Maximum"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowMaximumOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertMinimumOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Minimum"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowMinimumOperator; - op->inputs.push_back(node.input(0)); - 
op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} void ConvertUnsupportedOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, @@ -1341,22 +1071,12 @@ void ConvertUnsupportedOperator(const NodeDef& node, for (int i = 0; i < output_types.type_size(); ++i) { op->output_data_types.push_back(ConvertDataType(output_types.type(i))); } + } else if (HasAttr(node, "Tout")) { + const auto& output_type = GetDataTypeAttr(node, "Tout"); + op->output_data_types.push_back(ConvertDataType(output_type)); } } -void ConvertSelectOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CheckInputsCount(node, tf_import_flags, 3); - - auto* op = new SelectOperator; - for (const auto& input : node.input()) { - op->inputs.push_back(input); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertStridedSliceOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1635,17 +1355,6 @@ void ConvertBatchToSpaceNDOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertExpOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Exp"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new ExpOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertMeanOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1736,11 +1445,13 @@ void ConvertTransposeConvOperator(const NodeDef& node, if (existing_transpose) { CHECK(existing_transpose->type == OperatorType::kTranspose); } else { - // Transpose weights from HWIO order to OHWI order, which is more efficient - // for computation + // Transpose weights from HWOI order to OHWI order, which is more efficient + // for computation. 
(Note that TensorFlow considers the order as HWIO + // because they consider this a backward conv, inverting the sense of + // input/output.) TransposeOperator* transpose = new TransposeOperator; string perm_array = CreateConstArray( - model, node.name() + "_transpose_perm", {3, 0, 1, 2}); + model, node.name() + "_transpose_perm", {2, 0, 1, 3}); transpose->inputs = {weights_name, perm_array}; transpose->outputs = {transposed_weights_name}; model->operators.emplace_back(transpose); @@ -1759,53 +1470,6 @@ void ConvertTransposeConvOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertExpandDimsOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "ExpandDims"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new ExpandDimsOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertFillOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Fill"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new FillOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertFloorDivOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "FloorDiv"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new FloorDivOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertFloorModOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "FloorMod"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new 
FloorModOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} void ConvertRangeOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, @@ -1826,17 +1490,6 @@ void ConvertRangeOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertRankOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Rank"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new RankOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertStackOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1857,17 +1510,6 @@ void ConvertStackOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertTransposeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Transpose"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TransposeOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} // Some TensorFlow ops only occur in graph cycles, representing // control flow. 
We do not currently support control flow, so we wouldn't @@ -2090,6 +1732,24 @@ void ConvertDynamicStitchOperator(const NodeDef& node, model->operators.emplace_back(op.release()); } +void ConvertSparseToDenseOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "SparseToDense"); + CheckInputsCount(node, tf_import_flags, 4); + + auto* op = new SparseToDenseOperator; + for (const string& input : node.input()) { + op->inputs.push_back(input); + } + op->outputs.push_back(node.name()); + + op->validate_indices = HasAttr(node, "validate_indices") + ? GetBoolAttr(node, "validate_indices") + : true; + model->operators.emplace_back(op); +} + } // namespace namespace internal { @@ -2101,7 +1761,7 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, if (node.op() == "Const") { return ConvertConstOperator(node, tf_import_flags, model); } else if (node.op() == "Conv2D") { - ConvertConvOperator(node, tf_import_flags, model); + return ConvertConvOperator(node, tf_import_flags, model); } else if (node.op() == "Conv2DBackpropInput") { ConvertTransposeConvOperator(node, tf_import_flags, model); } else if (node.op() == "DepthwiseConv2dNative") { @@ -2113,25 +1773,26 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else if (node.op() == "BiasAdd") { ConvertBiasAddOperator(node, tf_import_flags, model); } else if (node.op() == "Relu") { - ConvertReluOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Relu6") { - ConvertRelu6Operator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Sigmoid") { - ConvertLogisticOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Tanh") { - ConvertTanhOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == 
"MaxPool") { ConvertMaxPoolOperator(node, tf_import_flags, model); } else if (node.op() == "AvgPool") { ConvertAvgPoolOperator(node, tf_import_flags, model); } else if (node.op() == "Reshape") { - ConvertReshapeOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "BatchMatMul") { ConvertBatchMatMulOperator(node, tf_import_flags, model); } else if (node.op() == "MatMul") { ConvertMatMulOperator(node, tf_import_flags, model); } else if (node.op() == "Div" || node.op() == "RealDiv") { - ConvertDivOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Identity" || node.op() == "CheckNumerics" || node.op() == "StopGradient") { ConvertIdentityOperator(node, tf_import_flags, model); @@ -2140,27 +1801,31 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else if (node.op() == "FakeQuantWithMinMaxArgs") { ConvertFakeQuantWithMinMaxArgs(node, tf_import_flags, model); } else if (node.op() == "Neg") { - ConvertNegOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Rsqrt") { - ConvertRsqrtOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Squeeze") { ConvertSqueezeOperator(node, tf_import_flags, model); } else if (node.op() == "Sqrt") { - ConvertSqrtOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Square") { - ConvertSquareOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Add") { - ConvertAddOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "AddN") { - ConvertAddNOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Mul") 
{ - ConvertMulOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Sub") { - ConvertSubOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Sum") { ConvertSumOperator(node, tf_import_flags, model); } else if (node.op() == "Tile") { - ConvertTileOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Concat" || node.op() == "ConcatV2") { ConvertConcatOperator(node, tf_import_flags, model); } else if (node.op() == "LRN") { @@ -2168,41 +1833,50 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else if (node.op() == "Softmax") { ConvertSoftmaxOperator(node, tf_import_flags, model); } else if (node.op() == "Log") { - ConvertLogOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "LogSoftmax") { - ConvertLogSoftmaxOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "All") { - ConvertAllOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Assert") { - ConvertAssertOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Less") { - ConvertLessOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "LessEqual") { - ConvertLessEqualOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Greater") { - ConvertGreaterOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "GreaterEqual") { - ConvertGreaterEqualOperator(node, tf_import_flags, model); + ConvertSimpleOperator( + node, tf_import_flags, model); } else 
if (node.op() == "Max") { ConvertMaxOperator(node, tf_import_flags, model); } else if (node.op() == "Min") { ConvertMinOperator(node, tf_import_flags, model); } else if (node.op() == "Maximum") { - ConvertMaximumOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Minimum") { - ConvertMinimumOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Merge") { - ConvertMergeOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Pad") { - ConvertPadOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "PadV2") { - ConvertPadV2Operator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "StridedSlice") { ConvertStridedSliceOperator(node, tf_import_flags, model); } else if (node.op() == "Shape") { - ConvertShapeOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Slice") { - ConvertSliceOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Split") { ConvertSplitOperator(node, tf_import_flags, model); } else if (node.op() == "Switch") { @@ -2239,25 +1913,25 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else if (node.op() == "NextIteration") { ConvertOperatorSpecialCasedAsRNNBackEdge(node, tf_import_flags, model); } else if (node.op() == "ExpandDims") { - ConvertExpandDimsOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Fill") { - ConvertFillOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "FloorDiv") { - ConvertFloorDivOperator(node, tf_import_flags, model); + 
ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "FloorMod") { - ConvertFloorModOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Range") { ConvertRangeOperator(node, tf_import_flags, model); } else if (node.op() == "Rank") { - ConvertRankOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Stack" || node.op() == "Pack") { ConvertStackOperator(node, tf_import_flags, model); } else if (node.op() == "Transpose") { - ConvertTransposeOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "ArgMax") { ConvertArgMaxOperator(node, tf_import_flags, model); } else if (node.op() == "Exp") { - ConvertExpOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "TopK" || node.op() == "TopKV2") { ConvertTopKV2Operator(node, tf_import_flags, model); } else if (node.op() == "DynamicPartition") { @@ -2267,8 +1941,20 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, ConvertDynamicStitchOperator(node, tf_import_flags, model); } else if (node.op() == "RandomUniform") { ConvertRandomUniform(node, tf_import_flags, model); + } else if (node.op() == "Sin") { + ConvertSimpleOperator(node, tf_import_flags, model); + } else if (node.op() == "Log") { + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Select") { - ConvertSelectOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); + } else if (node.op() == "SparseToDense") { + ConvertSparseToDenseOperator(node, tf_import_flags, model); + } else if (node.op() == "Equal") { + ConvertSimpleOperator(node, tf_import_flags, + model); + } else if (node.op() == "NotEqual") { + ConvertSimpleOperator(node, tf_import_flags, + model); } else { ConvertUnsupportedOperator(node, 
tf_import_flags, model); } diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc index 5dc78f73ad2e2ab6f1fcb1ee430513488ce47027..835676662b9cb7ed20e578e2a35747a64ba443dc 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc @@ -148,10 +148,11 @@ TEST_P(ShapeImportTest, ValidShapeButZeroElements) { NodeDef node; BuildConstNode({1, 2, 2, 2}, GetParam(), 0, &node); auto status = ImportNode(node); - EXPECT_THAT(status.error_message(), - ::testing::MatchesRegex( - "Neither input_content nor .*_val have the right dimensions " - "for this .* tensor .while processing node 'Node1'.")); + EXPECT_THAT( + status.error_message(), + ::testing::MatchesRegex( + "Neither input_content .0. nor .*_val .0. have the right " + "dimensions .8. for this .* tensor .while processing node 'Node1'.")); } INSTANTIATE_TEST_CASE_P(ValidShapeButZeroElements, ShapeImportTest, ::testing::ValuesIn(TestTypes())); diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 47f8db597846bf45df5ddbfd638e64f4bb9bab39..2f43adb07b1c9dc9645942ce6ec868595704baa5 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -78,6 +78,7 @@ enum class OperatorType { kFloor, kGather, kResizeBilinear, + kSin, kSpaceToBatchND, kStack, kBatchToSpaceND, @@ -134,6 +135,9 @@ enum class OperatorType { // special nodes in the graph to shuffle axes. kReorderAxes, kSelect, + kSparseToDense, + kTensorFlowEqual, + kTensorFlowNotEqual, }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -151,6 +155,7 @@ enum class AxesOrder { k1HWO, // Our standard for DepthwiseConv weights kHWIM, // TensorFlow DepthwiseConv weights kNHWC, // TensorFlow activations + kHWOI, // TensorFlow back-prop conv weights }; // The type of the scalars in an array. 
@@ -525,7 +530,15 @@ struct LstmCellOperator : Operator { ACTIV_TEMP = 3, NUM_OUTPUTS = 4 }; - LstmCellOperator() : Operator(OperatorType::kLstmCell) {} + enum KernelType { + KERNEL_BASIC = 0, + KERNEL_FULL = 1, + }; + + LstmCellOperator() + : Operator(OperatorType::kLstmCell), kernel_type(KERNEL_BASIC) {} + + KernelType kernel_type; }; // Element-wise multiplication operator. @@ -618,6 +631,17 @@ struct TanhOperator : Operator { TanhOperator() : Operator(OperatorType::kTanh) {} }; +// Element-wise Sin operator: +// x -> Sin(x) = sin(x) +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Sin +struct SinOperator : Operator { + SinOperator() : Operator(OperatorType::kSin) {} +}; + // Element-wise addition operator. // // Inputs: @@ -1337,6 +1361,22 @@ struct TensorFlowGreaterEqualOperator : Operator { : Operator(OperatorType::kTensorFlowGreaterEqual) {} }; +// TensorFlow Equal equivalent. Refer to TensorFlow documentation for +// details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Typically, this is only used as an input to an Assert node, so can be +// removed as an unused node as we drop Assert nodes. +struct TensorFlowEqualOperator : Operator { + TensorFlowEqualOperator() : Operator(OperatorType::kTensorFlowEqual) {} +}; + +// TensorFlow Not Equal equivalent. Refer to TensorFlow documentation for +// details. +struct TensorFlowNotEqualOperator : Operator { + TensorFlowNotEqualOperator() : Operator(OperatorType::kTensorFlowNotEqual) {} +}; + // Global max reduction: computes the max of all of entries in the input array. // Thus the output is "0-dimensional": it consists of a single scalar value. // @@ -1586,13 +1626,26 @@ struct DynamicStitchOperator : Operator { int num_partitions; }; +// SparseToDense operator: +// +// Inputs: +// Inputs[0]: required: sparse_indices. +// Inputs[1]: required: output_shape. 
+// Inputs[2]: required: sparse_values. +// +// TensorFlow equivalent: SparseToDense. +struct SparseToDenseOperator : Operator { + SparseToDenseOperator() : Operator(OperatorType::kSparseToDense) {} + bool validate_indices; +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. The 'start' and 'end' values are // offsets from the start of the workspace buffer, expressed in bytes. struct Alloc { - int start = 0; - int end = 0; + int64 start = 0; + int64 end = 0; }; inline bool operator<(const Alloc& a, const Alloc& b) { @@ -1817,6 +1870,8 @@ class Model { } const ArrayMap& GetArrayMap() const { return arrays; } + int64 ArithmeticOpsCount() const { return ops_count; } + // Optional arrays are used for optional tensors, // these tensors do not have data, but with reserved names as op inputs. std::set optional_arrays; @@ -1833,6 +1888,8 @@ class Model { std::size_t transient_data_size = 0; // For code-generation only: required alignment of the transient_data buffer std::size_t transient_data_alignment = 0; + // Arithmetic operations performed in the model. + int64 ops_count = 0; private: // The associative array mapping names to Array's. diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc index 7bbeab7c9d1e42d28f221f1a1134d9d05fe6ab51..4c9f1aa4b0274b5123bb3baa9b9fca1463bda4c3 100644 --- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc @@ -48,7 +48,7 @@ bool ParseModelFlagsFromCommandLineFlags( "that information from the input file."), Flag("input_arrays", parsed_flags.input_arrays.bind(), parsed_flags.input_arrays.default_value(), - "Names of the output arrays, comma-separated. If not specified, " + "Names of the input arrays, comma-separated. 
If not specified, " "will try to read that information from the input file."), Flag("output_array", parsed_flags.output_array.bind(), parsed_flags.output_array.default_value(), @@ -83,7 +83,7 @@ bool ParseModelFlagsFromCommandLineFlags( "Deprecated: use --input_data_types instead. Input array type, if " "not already provided in the graph. " "Typically needs to be specified when passing arbitrary arrays " - "to --input_array."), + "to --input_arrays."), Flag("input_data_types", parsed_flags.input_data_types.bind(), parsed_flags.input_data_types.default_value(), "Input arrays types, comma-separated, if not already provided in " @@ -124,14 +124,6 @@ bool ParseModelFlagsFromCommandLineFlags( parsed_flags.model_checks.default_value(), "A list of model checks to be applied to verify the form of the " "model. Applied after the graph transformations after import."), - Flag("graphviz_first_array", parsed_flags.graphviz_first_array.bind(), - parsed_flags.graphviz_first_array.default_value(), - "If set, defines the start of the sub-graph to be dumped to " - "GraphViz."), - Flag( - "graphviz_last_array", parsed_flags.graphviz_last_array.bind(), - parsed_flags.graphviz_last_array.default_value(), - "If set, defines the end of the sub-graph to be dumped to GraphViz."), Flag("dump_graphviz", parsed_flags.dump_graphviz.bind(), parsed_flags.dump_graphviz.default_value(), "Dump graphviz during LogDump call. 
If string is non-empty then " @@ -180,8 +172,6 @@ bool ParseModelFlagsFromCommandLineFlags( if (!tensorflow::Flags::Parse(argc, argv, flags)) return false; } auto& dump_options = *GraphVizDumpOptions::singleton(); - dump_options.graphviz_first_array = parsed_flags.graphviz_first_array.value(); - dump_options.graphviz_last_array = parsed_flags.graphviz_last_array.value(); dump_options.dump_graphviz_video = parsed_flags.dump_graphviz_video.value(); dump_options.dump_graphviz = parsed_flags.dump_graphviz.value(); diff --git a/tensorflow/contrib/lite/toco/python/BUILD b/tensorflow/contrib/lite/toco/python/BUILD index 6c4f8e12cdd5b3222997c4a2b0ac243cc74324e0..93fe756a55d378fa205ff88be5e18aff586e5dca 100644 --- a/tensorflow/contrib/lite/toco/python/BUILD +++ b/tensorflow/contrib/lite/toco/python/BUILD @@ -12,10 +12,11 @@ cc_library( deps = [ "//tensorflow/contrib/lite/toco:model_flags_proto_cc", "//tensorflow/contrib/lite/toco:toco_flags_proto_cc", + "//tensorflow/contrib/lite/toco:toco_graphviz_dump_options", "//tensorflow/contrib/lite/toco:toco_port", "//tensorflow/contrib/lite/toco:toco_tooling", "//tensorflow/core:lib", - "//util/python:python_headers", + "//third_party/python_runtime:headers", ], ) @@ -26,7 +27,7 @@ tf_py_wrap_cc( ":toco_python_api", "//tensorflow/contrib/lite/toco:model_flags_proto_cc", "//tensorflow/contrib/lite/toco:toco_flags_proto_cc", - "//util/python:python_headers", + "//third_party/python_runtime:headers", "@com_google_absl//absl/strings", ], ) @@ -41,12 +42,6 @@ py_binary( ], ) -py_binary( - name = "toco_wrapper", - srcs = ["toco_wrapper.py"], - srcs_version = "PY2AND3", -) - tf_py_test( name = "toco_from_protos_test", srcs = ["toco_from_protos_test.py"], diff --git a/tensorflow/contrib/lite/toco/python/toco.i b/tensorflow/contrib/lite/toco/python/toco.i index 3787cba4a371f1893d877daadcfe31e59eb5b3f6..0d2fbdd67b3aa59af9d5f32c4f1693fe044a7efa 100644 --- a/tensorflow/contrib/lite/toco/python/toco.i +++ 
b/tensorflow/contrib/lite/toco/python/toco.i @@ -24,9 +24,12 @@ namespace toco { // Convert a model represented in `input_contents`. `model_flags_proto` // describes model parameters. `toco_flags_proto` describes conversion // parameters (see relevant .protos for more information). Returns a string -// representing the contents of the converted model. +// representing the contents of the converted model. When extended_return +// flag is set to true returns a dictionary that contains string representation +// of the converted model and some statistics like arithmetic ops count. PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, PyObject* toco_flags_proto_txt_raw, - PyObject* input_contents_txt_raw); + PyObject* input_contents_txt_raw, + bool extended_return = false); } // namespace toco \ No newline at end of file diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.cc b/tensorflow/contrib/lite/toco/python/toco_python_api.cc index 153c117d17e4564d7cb0aaea64d792f63a587d91..d93e104038741e6e59608f04115854d611f1f9ae 100644 --- a/tensorflow/contrib/lite/toco/python/toco_python_api.cc +++ b/tensorflow/contrib/lite/toco/python/toco_python_api.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/python/toco_python_api.h" #include "tensorflow/contrib/lite/toco/toco_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" #include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/toco_tooling.h" #include "tensorflow/contrib/lite/toco/toco_types.h" @@ -37,7 +38,7 @@ namespace toco { // sure we input and output bytes rather than unicode strings for Python3. PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, PyObject* toco_flags_proto_txt_raw, - PyObject* input_contents_txt_raw) { + PyObject* input_contents_txt_raw, bool extended_return) { // Use Python C API to validate and convert arguments.
In py3 (bytes), // in py2 (str). auto ConvertArg = [&](PyObject* obj, bool* error) { @@ -62,7 +63,7 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, std::string input_contents_txt = ConvertArg(input_contents_txt_raw, &error); if (error) return nullptr; - // Use toco to produce new outputs + // Use TOCO to produce new outputs. toco::ModelFlags model_flags; if (!model_flags.ParseFromString(model_flags_proto_txt)) { LOG(FATAL) << "Model proto failed to parse." << std::endl; @@ -71,6 +72,16 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, if (!toco_flags.ParseFromString(toco_flags_proto_txt)) { LOG(FATAL) << "Toco proto failed to parse." << std::endl; } + + auto& dump_options = *GraphVizDumpOptions::singleton(); + if (toco_flags.has_dump_graphviz_dir()) { + dump_options.dump_graphviz = toco_flags.dump_graphviz_dir(); + } + if (toco_flags.has_dump_graphviz_include_video()) { + dump_options.dump_graphviz_video = toco_flags.dump_graphviz_include_video(); + } + + // Convert model. 
std::unique_ptr model = toco::Import(toco_flags, model_flags, input_contents_txt); toco::Transform(toco_flags, model.get()); @@ -78,6 +89,16 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, Export(toco_flags, *model, toco_flags.allow_custom_ops(), &output_file_contents_txt); + if (extended_return) { + PyObject* dict = PyDict_New(); + PyDict_SetItemString( + dict, "flatbuffer", + TOCO_FROM_CPPSTRING_TO_PY(output_file_contents_txt.data(), + output_file_contents_txt.size())); + PyDict_SetItemString(dict, "arithmetic_ops", + PyLong_FromLong(model->ArithmeticOpsCount())); + return dict; + } // Convert arguments back to byte (py3) or str (py2) return TOCO_FROM_CPPSTRING_TO_PY(output_file_contents_txt.data(), output_file_contents_txt.size()); diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.h b/tensorflow/contrib/lite/toco/python/toco_python_api.h index dc378353f79945f4fbb72305899b2b604be785ad..7e8ad9c1dafa68dd91e4a0eb3bfb742207878c59 100644 --- a/tensorflow/contrib/lite/toco/python/toco_python_api.h +++ b/tensorflow/contrib/lite/toco/python/toco_python_api.h @@ -15,18 +15,21 @@ limitations under the License. #ifndef _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_ #define _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_ -#include #include +#include namespace toco { // Convert a model represented in `input_contents`. `model_flags_proto` // describes model parameters. `toco_flags_proto` describes conversion // parameters (see relevant .protos for more information). Returns a string -// representing the contents of the converted model. +// representing the contents of the converted model. When extended_return +// flag is set to true returns a dictionary that contains string representation +// of the converted model and some statistics like arithmetic ops count.
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, PyObject* toco_flags_proto_txt_raw, - PyObject* input_contents_txt_raw); + PyObject* input_contents_txt_raw, + bool extended_return = false); } // namespace toco diff --git a/tensorflow/contrib/lite/toco/python/toco_wrapper.py b/tensorflow/contrib/lite/toco/python/toco_wrapper.py deleted file mode 100644 index 6d6b500d7eccd353f566a4bad76df35e0e849d95..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/toco/python/toco_wrapper.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Wrapper for runninmg toco binary embedded in pip site-package. - -NOTE: this mainly exists since PIP setup.py cannot install binaries to bin/. -It can only install Python "console-scripts." This will work as a console -script. See tools/pip_package/setup.py (search for CONSOLE_SCRIPTS). -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import sys - - -def main(): - # Pip installs the binary in aux-bin off of main site-package install. - # Just find it and exec, passing all arguments in the process. - # TODO(aselle): it is unfortunate to use all of tensorflow to lookup binary. - print("""TOCO from pip install is currently not working on command line. 
-Please use the python TOCO API or use -bazel run tensorflow/contrib/lite:toco -- from a TensorFlow source dir. -""") - sys.exit(1) - # TODO(aselle): Replace this when we find a way to run toco without - # blowing up executable size. - # binary = os.path.join(tf.__path__[0], 'aux-bin/toco') - # os.execvp(binary, sys.argv) diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index 335b496dccdbdb7e342515868e1d7195c98f0351..a2d753657b0bf6c88f5c94a20a1240fb7c13a37c 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -45,14 +45,20 @@ using ::tflite::Tensor; namespace { -details::OperatorKey GetOperatorKey(const ::toco::Operator& op) { +details::OperatorKey GetOperatorKey( + const ::toco::Operator& op, + const std::map>& ops_by_type) { string custom_code; if (op.type == OperatorType::kTensorFlowUnsupported) { const TensorFlowUnsupportedOperator& unsupported_op = static_cast(op); custom_code = unsupported_op.tensorflow_op; } - return details::OperatorKey(op.type, custom_code); + int version = 1; + if (ops_by_type.count(op.type) != 0) { + version = ops_by_type.at(op.type)->GetVersion(op); + } + return details::OperatorKey(op.type, custom_code, version); } } // Anonymous namespace. @@ -74,11 +80,13 @@ void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) { } } -void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map) { +void LoadOperatorsMap( + const Model& model, OperatorsMap* operators_map, + const std::map>& ops_by_type) { // First find a list of unique operator types. std::set keys; for (const auto& op : model.operators) { - keys.insert(GetOperatorKey(*op)); + keys.insert(GetOperatorKey(*op, ops_by_type)); } // Now assign indices to them and fill in the map. 
int index = 0; @@ -185,8 +193,9 @@ Offset>> ExportOperatorCodes( std::map> ordered_opcodes; for (const auto& op : model.operators) { - const details::OperatorKey operator_key = GetOperatorKey(*op); + const details::OperatorKey operator_key = GetOperatorKey(*op, ops_by_type); int op_index = operators_map.at(operator_key); + int op_version = operator_key.version; string name = HelpfulOperatorTypeName(*op); bool is_builtin = false; @@ -197,7 +206,7 @@ Offset>> ExportOperatorCodes( if (is_builtin) { ordered_opcodes[op_index] = - CreateOperatorCode(*builder, builtin_ops[name], 0); + CreateOperatorCode(*builder, builtin_ops[name], 0, op_version); } else { // This could be a kTensorFlowUnsupported, in which case we should be // able to retrieve the original Tensorflow name from the OperatorKey, or @@ -211,8 +220,9 @@ Offset>> ExportOperatorCodes( if (error_summary) { error_summary->insert(name); } - ordered_opcodes[op_index] = CreateOperatorCode( - *builder, BuiltinOperator_CUSTOM, builder->CreateString(name)); + ordered_opcodes[op_index] = + CreateOperatorCode(*builder, BuiltinOperator_CUSTOM, + builder->CreateString(name), op_version); } } @@ -244,7 +254,7 @@ Offset>> ExportOperators( outputs.push_back(tensors_map.at(output)); } - int op_index = operators_map.at(GetOperatorKey(*op)); + int op_index = operators_map.at(GetOperatorKey(*op, ops_by_type)); // This is a custom op unless we can find it in ops_by_type, and even then // it could be a custom op (such as kTensorFlowUnsupported). 
@@ -279,15 +289,20 @@ Offset>> ExportBuffers( void Export(const Model& model, bool allow_custom_ops, string* output_file_contents) { - flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240); - const auto ops_by_type = BuildOperatorByTypeMap(); + Export(model, allow_custom_ops, output_file_contents, ops_by_type); +} + +void Export( + const Model& model, bool allow_custom_ops, string* output_file_contents, + const std::map>& ops_by_type) { + flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240); details::TensorsMap tensors_map; details::LoadTensorsMap(model, &tensors_map); details::OperatorsMap operators_map; - details::LoadOperatorsMap(model, &operators_map); + details::LoadOperatorsMap(model, &operators_map, ops_by_type); std::vector buffers_to_write; Array empty_array; @@ -301,6 +316,7 @@ void Export(const Model& model, bool allow_custom_ops, auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map, &builder, &error_summary); const string fake_quant_operation_name = "FAKE_QUANT"; + if (error_summary.count(fake_quant_operation_name) != 0) { LOG(ERROR) << fake_quant_operation_name @@ -312,12 +328,29 @@ void Export(const Model& model, bool allow_custom_ops, error_summary.erase(fake_quant_operation_name); } if (!allow_custom_ops && !error_summary.empty()) { - LOG(QFATAL) << "Some of the operators in the model are not supported by " - "the standard TensorFlow Lite runtime. If you have a custom " - "implementation for them you can disable this error with " - "--allow_custom_ops. Here is a list of operators for which " - "you will need custom implementations: " - << absl::StrJoin(error_summary, ", ") << "."; + // Remove ExpandDims and ReorderAxes from unimplemented list unless they + // compose the list. Both ops are removed during graph transformations. + // However, if an op is unimplemented earlier in the model, the graph + // transformation is unable to run because the output shape is not defined. 
+ // This causes unnecessary confusion during model conversion time. + std::set error_summary_final; + for (const auto& op_type : error_summary) { + if (op_type != "ReorderAxes" && op_type != "ExpandDims") { + error_summary_final.insert(op_type); + } + } + if (error_summary_final.empty()) { + error_summary_final = error_summary; + } + + LOG(QFATAL) + << "Some of the operators in the model are not supported by " + "the standard TensorFlow Lite runtime. If you have a custom " + "implementation for them you can disable this error with " + "--allow_custom_ops, or by setting allow_custom_ops=True " + "when calling tf.contrib.lite.toco_convert(). Here is a list " + "of operators for which you will need custom implementations: " + << absl::StrJoin(error_summary_final, ", ") << "."; } auto ops = diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h index 8c79cb820015e16847ce48c171e8f6e41f60c319..098d2163e6c2fe26f3cb9cdf9959df62a1a4baf0 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.h +++ b/tensorflow/contrib/lite/toco/tflite/export.h @@ -16,6 +16,8 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_ #include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tflite/operator.h" +#include "tensorflow/contrib/lite/util.h" namespace toco { @@ -25,11 +27,18 @@ namespace tflite { // result in the given string. void Export(const Model& model, bool allow_custom_ops, string* output_file_contents); + // This if backward-compatibility. +// TODO(ycling): Remove the deprecated entry functions. inline void Export(const Model& model, string* output_file_contents) { Export(model, true, output_file_contents); } +// Export API with custom TFLite operator mapping. 
+void Export( + const Model& model, bool allow_custom_ops, string* output_file_contents, + const std::map>& ops_by_type); + namespace details { // A maps from tensor name to its final position in the TF Lite buffer. @@ -39,25 +48,35 @@ using TensorsMap = std::unordered_map; // Only when `type` is `kTensorFlowUnsupported`, `custom_code` is filled to // identify which operation is used. struct OperatorKey { - OperatorKey(OperatorType type, const std::string& custom_code) - : type(type), custom_code(custom_code) {} + OperatorKey(OperatorType type, const std::string& custom_code, int version) + : type(type), custom_code(custom_code), version(version) {} const OperatorType type; const std::string custom_code; + const int version; bool operator<(const OperatorKey& other) const { if (type < other.type) return true; - if (type > other.type) return false; - return custom_code < other.custom_code; + else if (type > other.type) + return false; + else if (custom_code < other.custom_code) + return true; + else if (custom_code > other.custom_code) + return false; + else + return version < other.version; } bool operator==(const OperatorKey& other) const { - return type == other.type && custom_code == other.custom_code; + return type == other.type && custom_code == other.custom_code && + version == other.version; } struct Hash { - std::size_t operator()(const OperatorKey& key) const { - return std::hash()(static_cast(key.type)) ^ - std::hash()(key.custom_code); + size_t operator()(const OperatorKey& key) const { + return ::tflite::CombineHashes( + {std::hash()(static_cast(key.type)), + std::hash()(key.custom_code), + std::hash()(key.version)}); } }; }; @@ -66,11 +85,12 @@ struct OperatorKey { using OperatorsMap = std::unordered_map; void LoadTensorsMap(const Model& model, TensorsMap* tensors_map); -void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map); +void LoadOperatorsMap( + const Model& model, OperatorsMap* operators_map, + const std::map>& ops_by_type); } // 
namespace details } // namespace tflite - } // namespace toco #endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc index 6754372330797ae30230af26a3b478c24ad44005..409e7d72a57076ec2832c5d12b52829477624f74 100644 --- a/tensorflow/contrib/lite/toco/tflite/export_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc @@ -17,6 +17,9 @@ limitations under the License. #include #include #include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h" +#include "tensorflow/contrib/lite/toco/tflite/operator.h" +#include "tensorflow/contrib/lite/toco/tflite/types.h" namespace toco { namespace tflite { @@ -65,12 +68,13 @@ TEST_F(ExportTest, LoadOperatorsMap) { BuildTestModel(); details::OperatorsMap operators; - details::LoadOperatorsMap(input_model_, &operators); - EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "")]); - EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "")]); - EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "")]); + const auto ops_by_type = BuildOperatorByTypeMap(); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]); + EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]); + EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]); EXPECT_EQ(3, operators[details::OperatorKey( - OperatorType::kTensorFlowUnsupported, "MyCrazyOp")]); + OperatorType::kTensorFlowUnsupported, "MyCrazyOp", 1)]); } TEST_F(ExportTest, Export) { @@ -104,6 +108,160 @@ TEST_F(ExportTest, Export) { EXPECT_THAT(indices, ElementsAre(1, 0, 3, 2)); } +// This test is based on a hypothetical scenario that dilation is supported +// only in Conv version 2. 
So Toco populates version=1 when dilation +// parameters are all 1, and version=2 otherwise. +class FakeConvolutionOperator + : public BuiltinOperator { + public: + FakeConvolutionOperator() + : BuiltinOperator(::tflite::BuiltinOperator_CONV_2D, + OperatorType::kConv) {} + + // Returning the op version according to the op parameters. + int GetVersion(const Operator& op) const override { + const TocoOperator& conv_op = static_cast(op); + if (conv_op.dilation_width_factor != 1 || + conv_op.dilation_height_factor != 1) { + // Version 2 if dilation is used. + return 2; + } + return 1; + } + + // Note: The read / write code doesn't need to be changed if we stick with + // the restrictions: + // * Only adding parameters at the bottom of the Flatbuffer tables. + // * When the default values of parameters are used, the op works consistently + // with the previous version. + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto padding = Padding::Serialize(op.padding.type); + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width, + op.stride_height, activation_function, + op.dilation_width_factor, + op.dilation_height_factor); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->padding.type = Padding::Deserialize(options.padding()); + op->stride_width = options.stride_w(); + op->stride_height = options.stride_h(); + op->dilation_width_factor = options.dilation_w_factor(); + op->dilation_height_factor = options.dilation_h_factor(); + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + } +}; + +class VersionedOpExportTest : public ::testing::Test { + protected: + void SetUp() override { + input_model_.GetOrCreateArray("input"); + input_model_.GetOrCreateArray("filter"); +
input_model_.GetOrCreateArray("output"); + } + void AddConvOp(bool use_dialation) { + { + auto* op = new ConvOperator; + op->inputs.push_back("input"); + op->inputs.push_back("filter"); + op->inputs.push_back("output"); + + op->padding.type = PaddingType::kSame; + op->stride_width = 1; + op->stride_height = 1; + if (use_dialation) { + op->dilation_width_factor = 2; + op->dilation_height_factor = 2; + } else { + op->dilation_width_factor = 1; + op->dilation_height_factor = 1; + } + input_model_.operators.emplace_back(op); + } + } + + std::map> + BuildFakeOperatorByTypeMap() { + std::map> result; + result[OperatorType::kConv] = + std::unique_ptr(new FakeConvolutionOperator); + return result; + } + + Model input_model_; +}; + +TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV1) { + AddConvOp(false); + + details::OperatorsMap operators; + const auto ops_by_type = BuildFakeOperatorByTypeMap(); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + + EXPECT_EQ(1, operators.size()); + EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1))); +} + +TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) { + AddConvOp(true); + + details::OperatorsMap operators; + const auto ops_by_type = BuildFakeOperatorByTypeMap(); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + + EXPECT_EQ(1, operators.size()); + EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 2))); +} + +TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) { + AddConvOp(false); + AddConvOp(true); + + details::OperatorsMap operators; + const auto ops_by_type = BuildFakeOperatorByTypeMap(); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + + EXPECT_EQ(2, operators.size()); + EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1))); + EXPECT_EQ(1, operators.at(details::OperatorKey(OperatorType::kConv, "", 2))); +} + +TEST_F(VersionedOpExportTest, Export) { + AddConvOp(false); + 
AddConvOp(true); + + string result; + const auto ops_by_type = BuildFakeOperatorByTypeMap(); + Export(input_model_, true, &result, ops_by_type); + + auto* model = ::tflite::GetModel(result.data()); + auto operator_codes = model->operator_codes(); + + // Verify that 2 operator codes are populated. Both are CONV_2D but with + // different versions. + EXPECT_EQ(2, operator_codes->size()); + EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D, + (*operator_codes)[0]->builtin_code()); + EXPECT_EQ(1, (*operator_codes)[0]->version()); + EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D, + (*operator_codes)[1]->builtin_code()); + EXPECT_EQ(2, (*operator_codes)[1]->version()); + + // Verify that the 2 operators point to the correct indices of the operation + // codes. + auto operators = (*model->subgraphs())[0]->operators(); + EXPECT_EQ(2, operators->size()); + EXPECT_EQ(0, (*operators)[0]->opcode_index()); + EXPECT_EQ(1, (*operators)[1]->opcode_index()); +} + // TODO(ahentz): tests for tensors, inputs, outpus, opcodes and operators. } // namespace diff --git a/tensorflow/contrib/lite/toco/tflite/import.cc b/tensorflow/contrib/lite/toco/tflite/import.cc index c0e7ab2ef57ed8edf1b7cda08c64f6ae66172af3..1be7cf07a7ffdec886dda94ac28218e233c39f3a 100644 --- a/tensorflow/contrib/lite/toco/tflite/import.cc +++ b/tensorflow/contrib/lite/toco/tflite/import.cc @@ -113,15 +113,34 @@ void ImportOperators( << operators_table.size(); } string opname = operators_table.at(index); + + // Find and use the appropriate operator deserialization factory.
+ std::unique_ptr new_op = nullptr; if (ops_by_name.count(opname) == 0) { - LOG(FATAL) << "Op '" << opname << "' not supported"; + string effective_opname = "TENSORFLOW_UNSUPPORTED"; + if (ops_by_name.count(effective_opname) == 0) { + LOG(FATAL) << "Internal logic error: TENSORFLOW_UNSUPPORTED not found."; + } + new_op = ops_by_name.at(effective_opname) + ->Deserialize(input_op->builtin_options(), + input_op->custom_options()); + if (TensorFlowUnsupportedOperator* unsupported_op = + dynamic_cast(new_op.get())) { + unsupported_op->tensorflow_op = opname; + // TODO(b/109932940): Remove this when quantized is removed. + // For now, we assume all ops are quantized. + unsupported_op->quantized = true; + } else { + LOG(FATAL) << "Expected a TensorFlowUnsupportedOperator"; + } + } else { + new_op = ops_by_name.at(opname)->Deserialize(input_op->builtin_options(), + input_op->custom_options()); } - - auto new_op = ops_by_name.at(opname)->Deserialize( - input_op->builtin_options(), input_op->custom_options()); model->operators.emplace_back(new_op.release()); auto* op = model->operators.back().get(); + // Make sure all the inputs and outputs are hooked up. 
auto inputs = input_op->inputs(); for (int i = 0; i < inputs->Length(); i++) { auto input_index = inputs->Get(i); diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 90e24aa104f5b0b52bb34c1974bdbd92fe37a3f5..7490ab960b9b0c62bef4c343927664ac6ae4eb9d 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -53,6 +53,8 @@ class AveragePool op->fused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Convolution @@ -83,6 +85,8 @@ class Convolution op->fused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class DepthwiseConvolution @@ -112,6 +116,8 @@ class DepthwiseConvolution op->fused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Add : public BuiltinOperatorfused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class SpaceToBatchND @@ -149,6 +157,8 @@ class SpaceToBatchND void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override {} + + int GetVersion(const Operator& op) const override { return 1; } }; class Sub : public BuiltinOperatorfused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Div : public BuiltinOperatorfused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class BatchToSpaceND @@ -206,6 +220,8 @@ class 
BatchToSpaceND void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override {} + + int GetVersion(const Operator& op) const override { return 1; } }; class Cast : public BuiltinOperatorsrc_data_type = DataType::Deserialize(options.in_data_type()); op->dst_data_type = DataType::Deserialize(options.out_data_type()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Concatenation @@ -243,6 +261,8 @@ class Concatenation TocoOperator* op) const override { op->axis = options.axis(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class DepthToSpace : public CustomOperator { @@ -255,6 +275,8 @@ class DepthToSpace : public CustomOperator { void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { op->block_size = m["block_size"].AsInt64(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class FakeQuant : public CustomOperator { @@ -274,6 +296,8 @@ class FakeQuant : public CustomOperator { const auto& num_bits = m["num_bits"]; op->num_bits = num_bits.IsInt() ? 
num_bits.AsInt32() : 8; } + + int GetVersion(const Operator& op) const override { return 1; } }; class FullyConnected @@ -295,6 +319,8 @@ class FullyConnected op->fused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Gather : public BuiltinOperatoraxis = options.axis(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Svdf : public BuiltinOperatorrank = options.rank(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class L2Normalization @@ -351,6 +381,8 @@ class L2Normalization op->fused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class L2Pool : public BuiltinOperatorfused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class LocalResponseNormalization @@ -401,6 +435,8 @@ class LocalResponseNormalization op->alpha = options.alpha(); op->beta = options.beta(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class MaxPool : public BuiltinOperatorfused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Mul : public BuiltinOperatorfused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Pad : public BuiltinOperator { + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateTileOptions(*builder); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) 
const override {} + int GetVersion(const Operator& op) const override { return 1; } }; class PadV2 : public BuiltinOperatorshape.insert(op->shape.end(), options.new_shape()->begin(), options.new_shape()->end()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Softmax @@ -516,6 +578,8 @@ class Softmax TocoOperator* op) const override { op->beta = options.beta(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class SpaceToDepth @@ -534,6 +598,8 @@ class SpaceToDepth TocoOperator* op) const override { op->block_size = options.block_size(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Transpose @@ -549,6 +615,8 @@ class Transpose void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override {} + + int GetVersion(const Operator& op) const override { return 1; } }; class Lstm : public BuiltinOperator WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { + ::tflite::LSTMKernelType kernel_type; + switch (op.kernel_type) { + case LstmCellOperator::KERNEL_BASIC: + kernel_type = ::tflite::LSTMKernelType_BASIC; + break; + case LstmCellOperator::KERNEL_FULL: + kernel_type = ::tflite::LSTMKernelType_FULL; + break; + } + // Current toco converter only supports tanh, no clip. 
return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/ ::tflite::ActivationFunctionType_TANH, /*cell_clip=*/0.0, - /*proj_clip=*/0.0); + /*proj_clip=*/0.0, kernel_type); } void ReadOptions(const TfLiteOptions& options, @@ -570,6 +648,25 @@ class Lstm : public BuiltinOperatorkernel_type = LstmCellOperator::KERNEL_BASIC; + break; + case ::tflite::LSTMKernelType_FULL: + op->kernel_type = LstmCellOperator::KERNEL_FULL; + break; + } + } + + int GetVersion(const Operator& op) const override { + const auto& lstm_op = static_cast(op); + switch (lstm_op.kernel_type) { + case LstmCellOperator::KERNEL_FULL: + return 1; + case LstmCellOperator::KERNEL_BASIC: + return 2; + } } }; @@ -587,6 +684,8 @@ class Mean : public BuiltinOperatorkeep_dims = options.keep_dims(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class ResizeBilinear @@ -605,6 +704,8 @@ class ResizeBilinear TocoOperator* op) const override { op->align_corners = options.align_corners(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Squeeze @@ -626,6 +727,8 @@ class Squeeze options.squeeze_dims()->begin(), options.squeeze_dims()->end()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Split @@ -644,6 +747,8 @@ class Split TocoOperator* op) const override { op->num_split = options.num_splits(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class StridedSlice @@ -668,6 +773,8 @@ class StridedSlice op->new_axis_mask = options.new_axis_mask(); op->shrink_axis_mask = options.shrink_axis_mask(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class TopK_V2 : public BuiltinOperatoroutput_data_type = DataType::Deserialize(options.output_type()); } + + int GetVersion(const Operator& op) const override { return 1; } +}; + +class TransposeConv + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + 
const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto padding = Padding::Serialize(op.padding.type); + return ::tflite::CreateTransposeConvOptions( + *builder, padding, op.stride_width, op.stride_height); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->padding.type = Padding::Deserialize(options.padding()); + op->stride_width = options.stride_w(); + op->stride_height = options.stride_h(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + +class SparseToDense + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateSparseToDenseOptions(*builder, op.validate_indices); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->validate_indices = options.validate_indices(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + +class ExpandDims + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateExpandDimsOptions(*builder); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override {} + + int GetVersion(const Operator& op) const override { return 1; } }; class TensorFlowUnsupported : public BaseOperator { @@ -805,6 +980,12 @@ class TensorFlowUnsupported : public BaseOperator { } node_def.SerializeToString(&op->tensorflow_node_def); } + + int GetVersion(const Operator& op) const override { + // TODO(ycling): Design and implement a way to plumb the version of + // custom ops.
+ return 1; + } }; namespace { @@ -877,6 +1058,14 @@ std::vector> BuildOperatorList() { new Cast(::tflite::BuiltinOperator_CAST, OperatorType::kCast)); ops.emplace_back( new ArgMax(::tflite::BuiltinOperator_ARG_MAX, OperatorType::kArgMax)); + ops.emplace_back( + new Tile(::tflite::BuiltinOperator_TILE, OperatorType::kTensorFlowTile)); + ops.emplace_back(new ExpandDims(::tflite::BuiltinOperator_EXPAND_DIMS, + OperatorType::kExpandDims)); + ops.emplace_back(new TransposeConv(::tflite::BuiltinOperator_TRANSPOSE_CONV, + OperatorType::kTransposeConv)); + ops.emplace_back(new SparseToDense(::tflite::BuiltinOperator_SPARSE_TO_DENSE, + OperatorType::kSparseToDense)); // Custom Operators. ops.emplace_back( @@ -923,9 +1112,18 @@ std::vector> BuildOperatorList() { "LESS", OperatorType::kTensorFlowLess)); ops.emplace_back(new SimpleOperator( "LESS_EQUAL", OperatorType::kTensorFlowLessEqual)); + ops.emplace_back(new SimpleOperator( + "EQUAL", OperatorType::kTensorFlowEqual)); + ops.emplace_back(new SimpleOperator( + "NOT_EQUAL", OperatorType::kTensorFlowNotEqual)); ops.emplace_back(new SimpleOperator("NEG", OperatorType::kNeg)); ops.emplace_back( new SimpleOperator("SELECT", OperatorType::kSelect)); + ops.emplace_back( + new SimpleOperator("SLICE", OperatorType::kSlice)); + // Element-wise operator + ops.emplace_back(new SimpleOperator("SIN", OperatorType::kSin)); + ops.emplace_back(new SimpleOperator("LOG", OperatorType::kLog)); return ops; } diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h index 85f7bdafe04979abc14f826ef667b3fa1aeec65c..5e9c20e40dd6274e0839379883b6dbe53064a0fc 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.h +++ b/tensorflow/contrib/lite/toco/tflite/operator.h @@ -77,6 +77,16 @@ class BaseOperator { const BuiltinOptions* builtin_options, const CustomOptions* custom_options) const = 0; + // Get the op version by op parameters. 
+ // The function needs to be overridden to return the op version based on the + // parameters. Note: + // * The first version for each op should be 1 (to be consistent with the + // default value in Flatbuffer). `return 1;` is okay for newly implemented + // ops. + // * When multiple versions are defined for an op, this function needs to be + // overridden. (See example in `operator_test.cc`.) + virtual int GetVersion(const Operator& op) const = 0; + private: string name_; OperatorType type_; diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index a4fff9974a6421b472a30b697e9c05fc404c9a01..e3144ad63e9f20e34ab0f7e217b4095b08eec9af 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -117,6 +117,13 @@ TEST_F(OperatorTest, SimpleOperators) { OperatorType::kTensorFlowLess); CheckSimpleOperator("NEG", OperatorType::kNeg); CheckSimpleOperator("SELECT", OperatorType::kSelect); + CheckSimpleOperator("SLICE", OperatorType::kSlice); + CheckSimpleOperator("SIN", OperatorType::kSin); + CheckSimpleOperator("EQUAL", + OperatorType::kTensorFlowEqual); + CheckSimpleOperator( + "NOT_EQUAL", OperatorType::kTensorFlowNotEqual); + CheckSimpleOperator("LOG", OperatorType::kLog); } TEST_F(OperatorTest, BuiltinAdd) { @@ -406,6 +413,27 @@ TEST_F(OperatorTest, BuiltinArgMax) { EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type); } +TEST_F(OperatorTest, BuiltinTransposeConv) { + TransposeConvOperator op; + op.stride_width = 123; + op.stride_height = 124; + op.padding.type = PaddingType::kValid; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("TRANSPOSE_CONV", OperatorType::kTransposeConv), op); + EXPECT_EQ(op.stride_width, output_toco_op->stride_width); + EXPECT_EQ(op.stride_height, output_toco_op->stride_height); + EXPECT_EQ(op.padding.type, output_toco_op->padding.type); +} + +TEST_F(OperatorTest, BuiltinSparseToDense) {
+ SparseToDenseOperator op; + op.validate_indices = false; + std::unique_ptr output_toco_op = + SerializeAndDeserialize( + GetOperator("SPARSE_TO_DENSE", OperatorType::kSparseToDense), op); + EXPECT_EQ(op.validate_indices, output_toco_op->validate_indices); +} + TEST_F(OperatorTest, TensorFlowUnsupported) { TensorFlowUnsupportedOperator op; op.tensorflow_op = "MyCustomUnsupportedOp"; diff --git a/tensorflow/contrib/lite/toco/tflite/simple_operator.h b/tensorflow/contrib/lite/toco/tflite/simple_operator.h index 72678c82a22a7168f858747b0b1c6a2b515b6578..a7f7e886f61d3bbf221c0ab7a24d6c3e629ec274 100644 --- a/tensorflow/contrib/lite/toco/tflite/simple_operator.h +++ b/tensorflow/contrib/lite/toco/tflite/simple_operator.h @@ -41,6 +41,8 @@ class SimpleOperator : public BaseOperator { const CustomOptions* custom_options) const override { return std::unique_ptr(new T); } + + int GetVersion(const Operator& op) const override { return 1; } }; } // namespace tflite diff --git a/tensorflow/contrib/lite/toco/tflite/types.cc b/tensorflow/contrib/lite/toco/tflite/types.cc index c9c2e9ba0184ef3f531f325091afaf6976e07f4f..4867c3a62e68406428644cd05bddf212008c2656 100644 --- a/tensorflow/contrib/lite/toco/tflite/types.cc +++ b/tensorflow/contrib/lite/toco/tflite/types.cc @@ -36,6 +36,16 @@ DataBuffer::FlatBufferOffset CopyStringToBuffer( return builder->CreateVector(dst_data.data(), bytes); } +// vector may be implemented using a bit-set, so we can't just +// reinterpret_cast, accessing its data as vector and let flatbuffer +// CreateVector handle it.
+// Background: https://isocpp.org/blog/2012/11/on-vectorbool +DataBuffer::FlatBufferOffset CopyBoolToBuffer( + const Array& array, flatbuffers::FlatBufferBuilder* builder) { + const auto& src_data = array.GetBuffer().data; + return builder->CreateVector(src_data); +} + template DataBuffer::FlatBufferOffset CopyBuffer( const Array& array, flatbuffers::FlatBufferBuilder* builder) { @@ -86,6 +96,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) { return ::tflite::TensorType_UINT8; case ArrayDataType::kString: return ::tflite::TensorType_STRING; + case ArrayDataType::kBool: + return ::tflite::TensorType_BOOL; default: // FLOAT32 is filled for unknown data types. // TODO(ycling): Implement type inference in TF Lite interpreter. @@ -105,6 +117,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) { return ArrayDataType::kString; case ::tflite::TensorType_UINT8: return ArrayDataType::kUint8; + case ::tflite::TensorType_BOOL: + return ArrayDataType::kBool; default: LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'."; } @@ -125,6 +139,8 @@ flatbuffers::Offset> DataBuffer::Serialize( return CopyStringToBuffer(array, builder); case ArrayDataType::kUint8: return CopyBuffer(array, builder); + case ArrayDataType::kBool: + return CopyBoolToBuffer(array, builder); default: LOG(FATAL) << "Unhandled array data type."; } @@ -146,6 +162,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor, return CopyStringFromBuffer(buffer, array); case ::tflite::TensorType_UINT8: return CopyBuffer(buffer, array); + case ::tflite::TensorType_BOOL: + return CopyBuffer(buffer, array); default: LOG(FATAL) << "Unhandled tensor type."; } diff --git a/tensorflow/contrib/lite/toco/tflite/types_test.cc b/tensorflow/contrib/lite/toco/tflite/types_test.cc index efb849f42283de5867c3b1ad914655212a13edb5..564f303b9bb41a777633ecabd666aa93ec3faefe 100644 --- a/tensorflow/contrib/lite/toco/tflite/types_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/types_test.cc @@ 
-28,8 +28,7 @@ using flatbuffers::Vector; // These are types that exist in TF Mini but don't have a correspondence // in TF Lite. -static const ArrayDataType kUnsupportedTocoTypes[] = {ArrayDataType::kNone, - ArrayDataType::kBool}; +static const ArrayDataType kUnsupportedTocoTypes[] = {ArrayDataType::kNone}; // These are TF Lite types for which there is no correspondence in TF Mini. static const ::tflite::TensorType kUnsupportedTfLiteTypes[] = { @@ -71,7 +70,8 @@ TEST(DataType, SupportedTypes) { {ArrayDataType::kUint8, ::tflite::TensorType_UINT8}, {ArrayDataType::kInt32, ::tflite::TensorType_INT32}, {ArrayDataType::kInt64, ::tflite::TensorType_INT64}, - {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32}}; + {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32}, + {ArrayDataType::kBool, ::tflite::TensorType_BOOL}}; for (auto x : testdata) { EXPECT_EQ(x.second, DataType::Serialize(x.first)); EXPECT_EQ(x.first, DataType::Deserialize(x.second)); @@ -158,6 +158,13 @@ TEST(DataBuffer, String) { ::testing::ElementsAre("AA", "BBB", "Best. String. 
Ever.")); } +TEST(DataBuffer, Bool) { + Array recovered = + ToFlatBufferAndBack({true, false, true}); + EXPECT_THAT(recovered.GetBuffer().data, + ::testing::ElementsAre(true, false, true)); +} + TEST(Padding, All) { EXPECT_EQ(::tflite::Padding_SAME, Padding::Serialize(PaddingType::kSame)); EXPECT_EQ(PaddingType::kSame, Padding::Deserialize(::tflite::Padding_SAME)); diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc index 7786a4ada335abc9a01a0a6e423125f2d67957c2..87a1e429b928bf59cb14597980602953732a7659 100644 --- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc @@ -153,6 +153,16 @@ bool ParseTocoFlagsFromCommandLineFlags( parsed_flags.dedupe_array_min_size_bytes.default_value(), "Minimum size of constant arrays to deduplicate; arrays smaller " "will not be deduplicated."), + Flag("split_tflite_lstm_inputs", + parsed_flags.split_tflite_lstm_inputs.bind(), + parsed_flags.split_tflite_lstm_inputs.default_value(), + "Split the LSTM inputs from 5 tensors to 18 tensors for TFLite. " + "Ignored if the output format is not TFLite."), + Flag("quantize_weights", parsed_flags.quantize_weights.bind(), + parsed_flags.quantize_weights.default_value(), + "Store weights as quantized weights followed by dequantize " + "operations. Computation is still done in float, but reduces model " + "size (at the cost of accuracy and latency)."), }; bool asked_for_help = *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); @@ -245,6 +255,8 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, READ_TOCO_FLAG(allow_nudging_weights_to_use_fast_gemm_kernel, FlagRequirement::kNone); READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone); + READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone); + READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone); // Deprecated flag handling. 
if (parsed_toco_flags.input_type.specified()) { @@ -278,6 +290,11 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, QCHECK(toco::IODataType_Parse(input_types[0], &input_type)); toco_flags->set_inference_input_type(input_type); } + if (parsed_toco_flags.quantize_weights.value()) { + QCHECK_NE(toco_flags->inference_type(), IODataType::QUANTIZED_UINT8) + << "quantize_weights is not supported with inference_type " + "QUANTIZED_UINT8."; + } #undef READ_TOCO_FLAG #undef PARSE_TOCO_FLAG diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto index 8589ca361dae2561207f9fa0c57b3240240c08d6..ad4e94ded9f9730842a257e065d9aec2b1cbfac8 100644 --- a/tensorflow/contrib/lite/toco/toco_flags.proto +++ b/tensorflow/contrib/lite/toco/toco_flags.proto @@ -37,7 +37,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 19. +// Next ID to use: 21. message TocoFlags { // Input file format optional FileFormat input_format = 1; @@ -165,4 +165,22 @@ message TocoFlags { // Minimum size of constant arrays to deduplicate; arrays smaller will not be // deduplicated. optional int64 dedupe_array_min_size_bytes = 18 [default = 64]; + + // Split the LSTM inputs from 5 tensors to 18 tensors for TFLite. + // Ignored if the output format is not TFLite. + optional bool split_tflite_lstm_inputs = 19 [default = true]; + + // Store weights as quantized weights followed by dequantize operations. + // Computation is still done in float, but reduces model size (at the cost of + // accuracy and latency). + optional bool quantize_weights = 20 [default = false]; + + // Full filepath of folder to dump the graphs at various stages of processing + // GraphViz .dot files. Preferred over --output_format=GRAPHVIZ_DOT in order + // to keep the requirements of the output file. 
+ optional string dump_graphviz_dir = 24; + + // Boolean indicating whether to dump the graph after every graph + // transformation. + optional bool dump_graphviz_include_video = 25; } diff --git a/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h index d6c3ba6543378b3e15b5fb7816f52376fe05123d..7cdd55e5422589aa000000b82d09b9d8397d7a88 100644 --- a/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h +++ b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h @@ -21,8 +21,6 @@ namespace toco { // Global data for determining whether to output graph viz format from toco. struct GraphVizDumpOptions { - std::string graphviz_first_array; - std::string graphviz_last_array; std::string dump_graphviz; bool dump_graphviz_video = false; diff --git a/tensorflow/contrib/lite/toco/toco_port.cc b/tensorflow/contrib/lite/toco/toco_port.cc index a1c8696cd06a30bfe8661bb70aa4f2d6d175aac3..3a5911c28dc5462b5d3747f6af6aa82026a23466 100644 --- a/tensorflow/contrib/lite/toco/toco_port.cc +++ b/tensorflow/contrib/lite/toco/toco_port.cc @@ -18,6 +18,12 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/toco_types.h" #include "tensorflow/core/platform/logging.h" +#if defined(__ANDROID__) && defined(__ARM_ARCH_7A__) +namespace std { +double round(double x) { return ::round(x); } +} // namespace std +#endif + namespace toco { namespace port { void CopyToBuffer(const string& src, char* dest) { diff --git a/tensorflow/contrib/lite/toco/toco_port.h b/tensorflow/contrib/lite/toco/toco_port.h index 906792ef569e5b8dd2a40f6cf683fa8a35946012..b00b1e89e856190787d2d40096c9a5321bd80604 100644 --- a/tensorflow/contrib/lite/toco/toco_port.h +++ b/tensorflow/contrib/lite/toco/toco_port.h @@ -33,6 +33,24 @@ limitations under the License. 
#define TFLITE_PROTO_NS google::protobuf #endif +#ifdef __ANDROID__ +#include +namespace std { + +template +std::string to_string(T value) +{ + std::ostringstream os ; + os << value ; + return os.str() ; +} + +#ifdef __ARM_ARCH_7A__ +double round(double x); +#endif +} +#endif + namespace toco { namespace port { diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 58c99051bd9424ff0b2b446bd50b4f1c158d22c6..1fe76f8163cdf23b27f8baaf2d9c6d99b1aa3747 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -86,6 +86,7 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveConstantRandomUniform); transformations->Add(new ResolveConstantRange); transformations->Add(new ResolveConstantReshape); + transformations->Add(new ResolveConstantSlice); transformations->Add(new ResolveConstantStack); transformations->Add(new ResolveConstantStridedSlice); transformations->Add(new ResolveConstantTranspose); @@ -262,12 +263,15 @@ void Transform(const TocoFlags& toco_flags, Model* model) { if (!toco_flags.debug_disable_recurrent_cell_fusion()) { transformations.Add(new IdentifyLstmCell); } - if (output_format == TFLITE) { + if (output_format == TFLITE && toco_flags.split_tflite_lstm_inputs()) { transformations.Add(new toco::SplitLstmCellInputs); } else { transformations.Add(new toco::MergeLstmCellInputs); } } + if (toco_flags.quantize_weights()) { + transformations.Add(new QuantizeWeights); + } transformations.Add(new ResolveConstantConcatenation); RunGraphTransformations(model, "general graph transformations", transformations); @@ -372,6 +376,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) { LOG(INFO) << "Estimated count of arithmetic ops: " << 1e-9 * ops_count << " billion (note that a multiply-add is counted as 2 ops)."; } + model->ops_count = ops_count; } void Export(const TocoFlags& toco_flags, const Model& model, diff --git 
a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 1f56fe5c833addb3aa56d42421b2744627167ccc..13e9331919c6e75317388b2e42f2f4d352f67344 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -337,6 +337,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(LogSoftmax) HANDLE_OPERATORTYPENAME_CASE(Div) HANDLE_OPERATORTYPENAME_CASE(Tanh) + HANDLE_OPERATORTYPENAME_CASE(Sin) HANDLE_OPERATORTYPENAME_CASE(TensorFlowAll) HANDLE_OPERATORTYPENAME_CASE(TensorFlowAssert) HANDLE_OPERATORTYPENAME_CASE(ExpandDims) @@ -392,6 +393,9 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(DynamicPartition) HANDLE_OPERATORTYPENAME_CASE(DynamicStitch) HANDLE_OPERATORTYPENAME_CASE(Select) + HANDLE_OPERATORTYPENAME_CASE(SparseToDense) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowEqual) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowNotEqual) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE @@ -916,7 +920,7 @@ void CheckEachArray(const Model& model) { CHECK(array->buffer->type == array->data_type); // The presence of a fixed buffer should imply the presence of a fixed // shape. - CHECK(array->has_shape()); + CHECK(array->has_shape()) << "Invalid array: " << array_entry.first; // Constant buffer should has a valid shape. 
for (int d : array->shape().dims()) { CHECK_GE(d, 1); @@ -986,7 +990,7 @@ void FixOperatorOrdering(Model* model) { for (auto i : remaining) { bool can_insert = true; auto& op = old_operators[i]; - CHECK(op.get()); + CHECK(op); for (const auto& input : op->inputs) { if (!IsConstantParameterArray(*model, input) && !arrays_behind_us.count(input)) { @@ -1861,18 +1865,15 @@ void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order, output_axes_order == AxesOrder::kHWIO) { // 3210 <- 3210 // HWIO <- OHWI - (*shuffle)[0] = 1; - (*shuffle)[1] = 2; - (*shuffle)[2] = 3; - (*shuffle)[3] = 0; + *shuffle = {1, 2, 3, 0}; } else if (input_axes_order == AxesOrder::kHWIO && output_axes_order == AxesOrder::kOHWI) { // 3210 <- 3210 // OHWI <- HWIO - (*shuffle)[0] = 3; - (*shuffle)[1] = 0; - (*shuffle)[2] = 1; - (*shuffle)[3] = 2; + *shuffle = {3, 0, 1, 2}; + } else if (input_axes_order == AxesOrder::kOHWI && + output_axes_order == AxesOrder::kHWOI) { + *shuffle = {1, 2, 0, 3}; } else { LOG(FATAL) << "Bad shuffle"; } @@ -2018,6 +2019,8 @@ int AxesCount(AxesOrder axes_order) { return 4; case AxesOrder::kNHWC: return 4; + case AxesOrder::kHWOI: + return 4; default: LOG(FATAL) << "Bad AxesOrder"; return 0; @@ -2073,15 +2076,21 @@ bool ReshapeIsEquivalentToTranspose(const Model& model, void CheckFinalDataTypesSatisfied(const Model& model) { for (const auto& array_entry : model.GetArrayMap()) { const auto& array = *array_entry.second; + if (array.data_type == ArrayDataType::kBool) { + // Boolean values are never quantized. + continue; + } + // If the final data type is int16, the data type may be float, for example // after dequantization. 
if (array.final_data_type != ArrayDataType::kNone && array.final_data_type != ArrayDataType::kInt16) { - CHECK(array.final_data_type == array.data_type) + CHECK(array.data_type == array.final_data_type) << "Array \"" << array_entry.first - << "\" has mis-matching actual and final data types (" - << ArrayDataTypeName(array.data_type) << "," - << ArrayDataTypeName(array.final_data_type) << ")."; + << "\" has mis-matching actual and final data types (data_type=" + << ArrayDataTypeName(array.data_type) + << ", final_data_type=" << ArrayDataTypeName(array.final_data_type) + << ")."; } } } diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index 1f596ca8e5a28f17e816c33eea03725d16f7ce12..3b320e801349595396e573e225ffacf4c7607e52 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -26,7 +26,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/core/platform/logging.h" #if TOCO_SUPPORT_PORTABLE_PROTOS -#include "third_party/protobuf/src/google/protobuf/text_format.h" +#include "third_party/protobuf/include/google/protobuf/text_format.h" #endif // TOCO_SUPPORT_PORTABLE_PROTOS #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD index 7b3569ea9c8b15959b15e8ba46cf44d159d5528c..5913847329eeae7373d0d21834dd37327e4068c4 100644 --- a/tensorflow/contrib/lite/tools/BUILD +++ b/tensorflow/contrib/lite/tools/BUILD @@ -7,6 +7,8 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +common_copts = ["-Wall"] + py_binary( name = "visualize", srcs = ["visualize.py"], @@ -28,35 +30,6 @@ tf_cc_binary( ], ) -tf_cc_binary( - name = "benchmark_model", - srcs = ["benchmark_model.cc"], - linkopts = 
select({ - "//tensorflow:android": [ - "-pie", - "-landroid", - "-lm", - "-z defs", - "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export - ], - "//conditions:default": [], - }), - deps = [ - ":mutable_op_resolver", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/kernels:builtin_ops", - ] + select({ - "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib", - ], - "//conditions:default": [ - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], - }), -) - cc_library( name = "gen_op_registration", srcs = ["gen_op_registration.cc"], @@ -87,13 +60,6 @@ cc_test( ], ) -cc_library( - name = "mutable_op_resolver", - srcs = ["mutable_op_resolver.cc"], - hdrs = ["mutable_op_resolver.h"], - deps = ["//tensorflow/contrib/lite:framework"], -) - cc_library( name = "verifier", srcs = ["verifier.cc"], @@ -103,7 +69,6 @@ cc_library( "//tensorflow/contrib/lite:schema_fbs_version", "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/schema:schema_fbs", - "@com_google_absl//absl/base:core_headers", ], ) @@ -115,11 +80,9 @@ cc_test( "tflite_not_portable", ], deps = [ - ":mutable_op_resolver", ":verifier", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:schema_fbs_version", - "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/schema:schema_fbs", "//tensorflow/contrib/lite/testing:util", "//tensorflow/core:framework_lite", diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..96c6b6872e3d6e44c9d1a8f642b135a664dd1eca --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/BUILD @@ -0,0 +1,91 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") 
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") + +common_copts = ["-Wall"] + tflite_copts() + +cc_binary( + name = "benchmark_model", + srcs = [ + "benchmark_main.cc", + "logging.h", + ], + copts = common_copts, + linkopts = tflite_linkopts() + select({ + "//tensorflow:android": [ + "-pie", # Android 5.0 and later supports only PIE + "-lm", # some builtin ops, e.g., tanh, need -lm + ], + "//conditions:default": [], + }), + deps = [ + ":benchmark_tflite_model_lib", + ], +) + +cc_library( + name = "command_line_flags", + srcs = ["command_line_flags.cc"], + hdrs = ["command_line_flags.h"], + copts = common_copts, + visibility = ["//visibility:private"], +) + +cc_test( + name = "command_line_flags_test", + srcs = ["command_line_flags_test.cc"], + copts = common_copts, + visibility = ["//visibility:private"], + deps = [ + ":command_line_flags", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "benchmark_tflite_model_lib", + srcs = [ + "benchmark_tflite_model.cc", + "logging.h", + ], + hdrs = ["benchmark_tflite_model.h"], + copts = common_copts, + linkopts = tflite_linkopts(), + deps = [ + ":benchmark_model_lib", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/profiling:profile_summarizer", + "//tensorflow/contrib/lite/profiling:profiler", + ], +) + +cc_library( + name = "benchmark_model_lib", + srcs = [ + "benchmark_model.cc", + "logging.h", + ], + hdrs = ["benchmark_model.h"], + copts = common_copts, + deps = [ + ":command_line_flags", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/profiling:profile_summarizer", + "//tensorflow/contrib/lite/profiling:profiler", + 
"//tensorflow/contrib/lite/profiling:time", + "//tensorflow/core:stats_calculator_portable", + ], +) + +tflite_portable_test_suite() diff --git a/tensorflow/contrib/lite/tools/benchmark/README.md b/tensorflow/contrib/lite/tools/benchmark/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c10826afff6d5569545d4b7df73c88d24d9dcd1a --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/README.md @@ -0,0 +1,154 @@ +# TFLite Model Benchmark Tool + +## Description + +A simple C++ binary to benchmark a TFLite model and its individual operators, +both on desktop machines and on Android. + +## To build/install/run + +### On Android: + +(0) Refer to https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android to edit the `WORKSPACE` to configure the android NDK/SDK. + +(1) Build for your specific platform, e.g.: + +``` +bazel build -c opt \ + --config=android_arm \ + --cxxopt='--std=c++11' \ + tensorflow/contrib/lite/tools/benchmark:benchmark_model +``` + +(2) Connect your phone. Push the binary to your phone with adb push + (make the directory if required): + +``` +adb push bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model /data/local/tmp +``` + +(3) Make the binary executable. + +``` +adb shell chmod +x /data/local/tmp/benchmark_model +``` + +(4) Push the compute graph that you need to test. For example: + +``` +adb push mobilenet_quant_v1_224.tflite /data/local/tmp +``` + +(5) Run the benchmark. For example: + +``` +adb shell /data/local/tmp/benchmark_model \ + --graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \ + --input_layer="Placeholder" \ + --input_layer_shape="1,224,224,3" \ + --num_threads=4 +``` + +### On desktop: +(1) build the binary + +``` +bazel build -c opt tensorflow/contrib/lite/tools/benchmark:benchmark_model +``` + +(2) Run on your compute graph, similar to the Android case but without the need of adb shell. 
+For example: + +``` +bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model \ + --graph=mobilenet_quant_v1_224.tflite \ + --input_layer="Placeholder" \ + --input_layer_shape="1,224,224,3" \ + --num_threads=4 +``` + +The MobileNet graph used as an example here may be downloaded from +https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip + +## Profiling model operators +The benchmark model binary can also profile operators and report the execution time of each operator. To do this, +compile the binary with a compiler flag that builds in profiling support: pass **--copt=-DTFLITE_PROFILING_ENABLED** +to the build command. +For example, to compile with profiling support on Android, add this flag to the previous command: + +``` +bazel build -c opt \ + --config=android_arm \ + --cxxopt='--std=c++11' \ + --copt=-DTFLITE_PROFILING_ENABLED \ + tensorflow/contrib/lite/tools/benchmark:benchmark_model +``` +This compiles TFLite with profiling enabled; you can then run the benchmark binary as before.
The binary will produce detailed statistics for each operation similar to those shown below: + +``` + +============================== Run Order ============================== + [node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name] + CONV_2D 0.000 4.269 4.269 0.107% 0.107% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6] + DEPTHWISE_CONV_2D 4.270 2.150 2.150 0.054% 0.161% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_depthwise/Relu6] + CONV_2D 6.421 6.107 6.107 0.153% 0.314% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6] + DEPTHWISE_CONV_2D 12.528 1.366 1.366 0.034% 0.348% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_depthwise/Relu6] + CONV_2D 13.895 4.195 4.195 0.105% 0.454% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_pointwise/Relu6] + DEPTHWISE_CONV_2D 18.091 1.260 1.260 0.032% 0.485% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_depthwise/Relu6] + CONV_2D 19.352 6.652 6.652 0.167% 0.652% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6] + DEPTHWISE_CONV_2D 26.005 0.698 0.698 0.018% 0.670% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_depthwise/Relu6] + CONV_2D 26.703 3.344 3.344 0.084% 0.754% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_pointwise/Relu6] + DEPTHWISE_CONV_2D 30.047 0.646 0.646 0.016% 0.770% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_depthwise/Relu6] + CONV_2D 30.694 5.800 5.800 0.145% 0.915% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6] + DEPTHWISE_CONV_2D 36.495 0.331 0.331 0.008% 0.924% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_depthwise/Relu6] + CONV_2D 36.826 2.838 2.838 0.071% 0.995% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6] + DEPTHWISE_CONV_2D 39.665 0.439 0.439 0.011% 1.006% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_depthwise/Relu6] + CONV_2D 40.105 5.293 5.293 0.133% 1.139% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6] + DEPTHWISE_CONV_2D 45.399 0.352 0.352 0.009% 1.147% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_depthwise/Relu6] + CONV_2D 45.752 5.322 5.322 0.133% 1.281% 
0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6] + DEPTHWISE_CONV_2D 51.075 0.357 0.357 0.009% 1.290% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_depthwise/Relu6] + CONV_2D 51.432 5.693 5.693 0.143% 1.433% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6] + DEPTHWISE_CONV_2D 57.126 0.366 0.366 0.009% 1.442% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_depthwise/Relu6] + CONV_2D 57.493 5.472 5.472 0.137% 1.579% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6] + DEPTHWISE_CONV_2D 62.966 0.364 0.364 0.009% 1.588% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_depthwise/Relu6] + CONV_2D 63.330 5.404 5.404 0.136% 1.724% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6] + DEPTHWISE_CONV_2D 68.735 0.155 0.155 0.004% 1.728% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_depthwise/Relu6] + CONV_2D 68.891 2.970 2.970 0.074% 1.802% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_pointwise/Relu6] + DEPTHWISE_CONV_2D 71.862 0.206 0.206 0.005% 1.807% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_depthwise/Relu6] + CONV_2D 72.069 5.888 5.888 0.148% 1.955% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6] + AVERAGE_POOL_2D 77.958 0.036 0.036 0.001% 1.956% 0.000 0 [MobilenetV1/Logits/AvgPool_1a/AvgPool] + CONV_2D 77.994 1.445 1.445 0.036% 1.992% 0.000 0 [MobilenetV1/Logits/Conv2d_1c_1x1/BiasAdd] + RESHAPE 79.440 0.002 0.002 0.000% 1.992% 0.000 0 [MobilenetV1/Predictions/Reshape] + SOFTMAX 79.443 0.029 0.029 0.001% 1.993% 0.000 0 [MobilenetV1/Predictions/Softmax] + +============================== Top by Computation Time ============================== + [node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name] + CONV_2D 19.352 6.652 6.652 0.167% 0.167% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6] + CONV_2D 6.421 6.107 6.107 0.153% 0.320% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6] + CONV_2D 72.069 5.888 5.888 0.148% 0.468% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6] + CONV_2D 30.694 5.800 
5.800 0.145% 0.613% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6] + CONV_2D 51.432 5.693 5.693 0.143% 0.756% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6] + CONV_2D 57.493 5.472 5.472 0.137% 0.893% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6] + CONV_2D 63.330 5.404 5.404 0.136% 1.029% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6] + CONV_2D 45.752 5.322 5.322 0.133% 1.162% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6] + CONV_2D 40.105 5.293 5.293 0.133% 1.295% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6] + CONV_2D 0.000 4.269 4.269 0.107% 1.402% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6] + +Number of nodes executed: 31 +============================== Summary by node type ============================== + [Node type] [count] [avg ms] [avg %] [cdf %] [mem KB] [times called] + CONV_2D 15 1.406 89.270% 89.270% 0.000 0 + DEPTHWISE_CONV_2D 13 0.169 10.730% 100.000% 0.000 0 + SOFTMAX 1 0.000 0.000% 100.000% 0.000 0 + RESHAPE 1 0.000 0.000% 100.000% 0.000 0 + AVERAGE_POOL_2D 1 0.000 0.000% 100.000% 0.000 0 + +Timings (microseconds): count=50 first=79449 curr=81350 min=77385 max=88213 avg=79732 std=1929 +Memory (bytes): count=0 +31 nodes observed + + +Average inference timings in us: Warmup: 83235, Init: 38467, no stats: 79760.9 +``` + + diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_main.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_main.cc new file mode 100644 index 0000000000000000000000000000000000000000..372d31e838e5666df492ee3156022249a2d97691 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_main.cc @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h" +#include "tensorflow/contrib/lite/tools/benchmark/logging.h" + +namespace tflite { +namespace benchmark { + +int Main(int argc, char** argv) { +#ifdef TFLITE_CUSTOM_OPS_HEADER + TFLITE_LOG(INFO) << "STARTING with custom ops!"; +#else + TFLITE_LOG(INFO) << "STARTING!"; +#endif + BenchmarkTfLiteModel benchmark; + BenchmarkLoggingListener listener; + benchmark.AddListener(&listener); + benchmark.Run(argc, argv); + return 0; +} +} // namespace benchmark +} // namespace tflite + +int main(int argc, char** argv) { return tflite::benchmark::Main(argc, argv); } diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc new file mode 100644 index 0000000000000000000000000000000000000000..a8a9a6112c1ec050be8d0bcfe9dc5f00df40d3ff --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc @@ -0,0 +1,139 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h" + +#include <time.h> + +#include <iostream> +#include <sstream> + +#include "tensorflow/contrib/lite/profiling/time.h" +#include "tensorflow/contrib/lite/tools/benchmark/logging.h" + +namespace { +void SleepForSeconds(double sleep_seconds) { + if (sleep_seconds <= 0.0) { + return; + } + // Convert the delay in seconds into a timespec. + timespec req; + req.tv_sec = static_cast<time_t>(sleep_seconds); + req.tv_nsec = (sleep_seconds - req.tv_sec) * 1000000000; + // If requested, sleep between runs for an arbitrary amount of time. + // This can be helpful to determine the effect of mobile processor + // scaling and thermal throttling. +#ifdef PLATFORM_WINDOWS + Sleep(sleep_seconds * 1000); +#else + nanosleep(&req, nullptr); +#endif +} + +} // namespace + +namespace tflite { +namespace benchmark { +using tensorflow::Stat; + +void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults &results) { + auto inference_us = results.inference_time_us(); + auto init_us = results.startup_latency_us(); + auto warmup_us = results.warmup_time_us(); + TFLITE_LOG(INFO) << "Average inference timings in us: " + << "Warmup: " << warmup_us.avg() << ", " + << "Init: " << init_us << ", " + << "no stats: " << inference_us.avg(); +} + +std::vector<Flag> BenchmarkModel::GetFlags() { + return { + Flag("num_runs", &params_.num_runs, "number of runs"), + Flag("run_delay", &params_.run_delay, "delay between runs in seconds"), + Flag("num_threads", &params_.num_threads, "number of threads"), + Flag("benchmark_name", &params_.benchmark_name, "benchmark name"), + Flag("output_prefix", &params_.output_prefix, "benchmark output prefix"), + Flag("warmup_runs", &params_.warmup_runs, + "how many runs to initialize model"), + }; +} + +void BenchmarkModel::LogFlags() { + TFLITE_LOG(INFO) << "Num runs: [" <<
params_.num_runs << "]"; + TFLITE_LOG(INFO) << "Inter-run delay (seconds): [" << params_.run_delay + << "]"; + TFLITE_LOG(INFO) << "Num threads: [" << params_.num_threads << "]"; + TFLITE_LOG(INFO) << "Benchmark name: [" << params_.benchmark_name << "]"; + TFLITE_LOG(INFO) << "Output prefix: [" << params_.output_prefix << "]"; + TFLITE_LOG(INFO) << "Warmup runs: [" << params_.warmup_runs << "]"; +} + +Stat<int64_t> BenchmarkModel::Run(int num_times, RunType run_type) { + Stat<int64_t> run_stats; + TFLITE_LOG(INFO) << "Running benchmark for " << num_times << " iterations "; + for (int run = 0; run < num_times; run++) { + listeners_.OnSingleRunStart(run_type); + int64_t start_us = profiling::time::NowMicros(); + RunImpl(); + int64_t end_us = profiling::time::NowMicros(); + listeners_.OnSingleRunEnd(); + + run_stats.UpdateStat(end_us - start_us); + SleepForSeconds(params_.run_delay); + } + + std::stringstream stream; + run_stats.OutputToStream(&stream); + TFLITE_LOG(INFO) << stream.str() << std::endl; + + return run_stats; +} + +void BenchmarkModel::Run(int argc, char **argv) { + if (!ParseFlags(argc, argv)) { + return; + } + + LogFlags(); + + listeners_.OnBenchmarkStart(params_); + int64_t initialization_start_us = profiling::time::NowMicros(); + Init(); + int64_t initialization_end_us = profiling::time::NowMicros(); + int64_t startup_latency_us = initialization_end_us - initialization_start_us; + TFLITE_LOG(INFO) << "Initialized session in " << startup_latency_us / 1e3 + << "ms"; + + uint64_t input_bytes = ComputeInputBytes(); + Stat<int64_t> warmup_time_us = Run(params_.warmup_runs, WARMUP); + Stat<int64_t> inference_time_us = Run(params_.num_runs, REGULAR); + listeners_.OnBenchmarkEnd( + {startup_latency_us, input_bytes, warmup_time_us, inference_time_us}); +} + +bool BenchmarkModel::ParseFlags(int argc, char **argv) { + auto flag_list = GetFlags(); + const bool parse_result = + Flags::Parse(&argc, const_cast<const char**>(argv), flag_list); + if (!parse_result) { + std::string usage = Flags::Usage(argv[0],
flag_list); + TFLITE_LOG(ERROR) << usage; + return false; + } + return ValidateFlags(); +} + +} // namespace benchmark +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h new file mode 100644 index 0000000000000000000000000000000000000000..d48f693693c2cee0cd2e2a6f2b4c590998feffb3 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h @@ -0,0 +1,161 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" +#include "tensorflow/core/util/stats_calculator.h" + +namespace tflite { +namespace benchmark { + +enum RunType { + WARMUP, + REGULAR, +}; + +class BenchmarkResults { + public: + BenchmarkResults(int64_t startup_latency_us, uint64_t input_bytes, + tensorflow::Stat<int64_t> warmup_time_us, + tensorflow::Stat<int64_t> inference_time_us) + : startup_latency_us_(startup_latency_us), + input_bytes_(input_bytes), + warmup_time_us_(warmup_time_us), + inference_time_us_(inference_time_us) {} + + tensorflow::Stat<int64_t> inference_time_us() const { + return inference_time_us_; + } + tensorflow::Stat<int64_t> warmup_time_us() const { return warmup_time_us_; } + int64_t startup_latency_us() const { return startup_latency_us_; } + uint64_t input_bytes() const { return input_bytes_; } + double throughput_MB_per_second() const { + double bytes_per_sec = (input_bytes_ * inference_time_us_.count() * 1e6) / + inference_time_us_.sum(); + return bytes_per_sec / (1024.0 * 1024.0); + } + + private: + int64_t startup_latency_us_; + uint64_t input_bytes_; + tensorflow::Stat<int64_t> warmup_time_us_; + tensorflow::Stat<int64_t> inference_time_us_; +}; + +struct BenchmarkParams { + BenchmarkParams() + : num_runs(50), warmup_runs(1), run_delay(-1.0), num_threads(1) {} + int num_runs; + int warmup_runs; + float run_delay; + int num_threads; + std::string benchmark_name; + std::string output_prefix; +}; + +class BenchmarkListener { + public: + virtual void OnBenchmarkStart(const BenchmarkParams& params) {} + virtual void OnSingleRunStart(RunType runType) {} + virtual void OnSingleRunEnd() {} + virtual void OnBenchmarkEnd(const BenchmarkResults& results) {} + virtual ~BenchmarkListener() {} +}; + +// A listener
that forwards its method calls to a collection of listeners. +class BenchmarkListeners : public BenchmarkListener { + public: + // Adds a listener to the listener collection. + // |listener| is not owned by the instance of |BenchmarkListeners|. + // |listener| should not be null and should outlive the instance of + // |BenchmarkListeners|. + void AddListener(BenchmarkListener* listener) { + listeners_.push_back(listener); + } + + void OnBenchmarkStart(const BenchmarkParams& params) override { + for (auto listener : listeners_) { + listener->OnBenchmarkStart(params); + } + } + + void OnSingleRunStart(RunType runType) override { + for (auto listener : listeners_) { + listener->OnSingleRunStart(runType); + } + } + + void OnSingleRunEnd() override { + for (auto listener : listeners_) { + listener->OnSingleRunEnd(); + } + } + + void OnBenchmarkEnd(const BenchmarkResults& results) override { + for (auto listener : listeners_) { + listener->OnBenchmarkEnd(results); + } + } + + ~BenchmarkListeners() {} + + private: + // Use vector so listeners are invoked in the order they are added. + std::vector<BenchmarkListener*> listeners_; +}; + +// Benchmark listener that just logs the results of a benchmark run. +class BenchmarkLoggingListener : public BenchmarkListener { + void OnBenchmarkEnd(const BenchmarkResults& results) override; +}; + +// Benchmarks a model. +// +// Subclasses need to implement initialization and running of the model. +// The results can be collected by adding BenchmarkListener(s).
+class BenchmarkModel { + public: + virtual ~BenchmarkModel() {} + bool ParseFlags(int argc, char** argv); + virtual void Init() = 0; + void Run(int argc, char** argv); + void AddListener(BenchmarkListener* listener) { + listeners_.AddListener(listener); + } + + protected: + virtual void LogFlags(); + virtual bool ValidateFlags() { return true; } + virtual std::vector<Flag> GetFlags(); + virtual uint64_t ComputeInputBytes() = 0; + virtual tensorflow::Stat<int64_t> Run(int num_times, RunType run_type); + virtual void RunImpl() = 0; + BenchmarkParams params_; + BenchmarkListeners listeners_; +}; + +} // namespace benchmark +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc new file mode 100644 index 0000000000000000000000000000000000000000..5f803cec197858953180d379c763ed7ebd34ee1d --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc @@ -0,0 +1,304 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/op_resolver.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/tools/benchmark/logging.h" + +#ifdef TFLITE_CUSTOM_OPS_HEADER +void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); +#endif + +namespace tflite { +namespace benchmark { + +void ProfilingListener::SetInterpreter(tflite::Interpreter* interpreter) { + TFLITE_BENCHMARK_CHECK(interpreter); + interpreter_ = interpreter; + interpreter_->SetProfiler(&profiler_); +} + +void ProfilingListener::OnSingleRunStart(RunType run_type) { + if (run_type == REGULAR) { + profiler_.Reset(); + profiler_.StartProfiling(); + } +} + +void ProfilingListener::OnBenchmarkEnd(const BenchmarkResults& results) { + if (has_profiles_) { + TFLITE_LOG(INFO) << summarizer_.GetOutputString(); + } +} + +void ProfilingListener::OnSingleRunEnd() { + profiler_.StopProfiling(); + auto profile_events = profiler_.GetProfileEvents(); + has_profiles_ = !profile_events.empty(); + summarizer_.ProcessProfiles(profile_events, *interpreter_); +} + +namespace { + +std::vector<std::string> Split(const std::string& str, const char delim) { + std::istringstream input(str); + std::vector<std::string> results; + std::string item; + while (std::getline(input, item, delim)) { + results.push_back(item); + } + return results; +} + +template <typename T> +bool SplitAndParse(const std::string& str, char delim, std::vector<T>* values) { + std::istringstream input(str); + bool first = true; + while (!input.eof()) { + if (!first) { + char c; + input >> c; + if (c != delim) { + return false; + } + } else { + first = false; + } + T val; + input >> val; + if (!input.eof() && !input.good()) { + return
false; + } + values->push_back(val); + } + return true; +} + +template <typename T> +void FillRandomValue(T* ptr, const std::vector<int>& sizes, + const std::function<T()>& random_func) { + int num_elements = 1; + for (int dim : sizes) { + num_elements *= dim; + } + for (int i = 0; i < num_elements; ++i) { + *ptr++ = random_func(); + } +} + +void FillRandomString(tflite::DynamicBuffer* buffer, + const std::vector<int>& sizes, + const std::function<std::string()>& random_func) { + int num_elements = 1; + for (int dim : sizes) { + num_elements *= dim; + } + for (int i = 0; i < num_elements; ++i) { + auto str = random_func(); + buffer->AddString(str.data(), str.length()); + } +} + +bool PopulateInputLayerInfo( + const string& names_string, const string& shapes_string, + std::vector<BenchmarkTfLiteModel::InputLayerInfo>* info) { + std::vector<std::string> names = Split(names_string, ','); + std::vector<std::string> shapes = Split(shapes_string, ':'); + + if (names.size() != shapes.size()) { + TFLITE_LOG(ERROR) << "The number of items in" + << " --input_layer_shape (" << shapes_string << ", with " + << shapes.size() << " items)" + << " must match the number of items in" + << " --input_layer (" << names_string << ", with " + << names.size() << " items)."
+ << " For example --input_layer=input1,input2" + << " --input_layer_shape=1,224,224,4:1,20"; + return false; + } + + for (int i = 0; i < names.size(); ++i) { + info->push_back(BenchmarkTfLiteModel::InputLayerInfo()); + BenchmarkTfLiteModel::InputLayerInfo& input = info->back(); + + input.name = names[i]; + + TFLITE_BENCHMARK_CHECK(SplitAndParse(shapes[i], ',', &input.shape)) + << "Incorrect size string specified: " << shapes[i]; + for (int dim : input.shape) { + if (dim == -1) { + TFLITE_LOG(ERROR) + << "Any unknown sizes in the shapes (-1's) must be replaced" + << " with the size you want to benchmark with."; + return false; + } + } + } + + return true; +} + +} // namespace + +std::vector BenchmarkTfLiteModel::GetFlags() { + std::vector flags = BenchmarkTfLiteModel::BenchmarkModel::GetFlags(); + std::vector specific_flags = { + Flag("graph", &graph, "graph file name"), + Flag("input_layer", &input_layer_string, "input layer names"), + Flag("input_layer_shape", &input_layer_shape_string, "input layer shape"), + Flag("use_nnapi", &use_nnapi, "use nnapi api")}; + + flags.insert(flags.end(), specific_flags.begin(), specific_flags.end()); + return flags; +} + +void BenchmarkTfLiteModel::LogFlags() { + BenchmarkModel::LogFlags(); + TFLITE_LOG(INFO) << "Graph: [" << graph << "]"; + TFLITE_LOG(INFO) << "Input layers: [" << input_layer_string << "]"; + TFLITE_LOG(INFO) << "Input shapes: [" << input_layer_shape_string << "]"; + TFLITE_LOG(INFO) << "Use nnapi : [" << use_nnapi << "]"; +} + +bool BenchmarkTfLiteModel::ValidateFlags() { + if (graph.empty()) { + TFLITE_LOG(ERROR) + << "Please specify the name of your TF Lite input file with --graph"; + return false; + } + return PopulateInputLayerInfo(input_layer_string, input_layer_shape_string, + &inputs); +} + +uint64_t BenchmarkTfLiteModel::ComputeInputBytes() { + TFLITE_BENCHMARK_CHECK(interpreter); + uint64_t total_input_bytes = 0; + for (int input : interpreter->inputs()) { + auto* t = interpreter->tensor(input); + 
total_input_bytes += t->bytes; + } + return total_input_bytes; +} + +void BenchmarkTfLiteModel::Init() { + model = tflite::FlatBufferModel::BuildFromFile(graph.c_str()); + if (!model) { + TFLITE_LOG(FATAL) << "Failed to mmap model " << graph; + } + TFLITE_LOG(INFO) << "Loaded model " << graph; + model->error_reporter(); + TFLITE_LOG(INFO) << "resolved reporter"; + +#ifdef TFLITE_CUSTOM_OPS_HEADER + tflite::MutableOpResolver resolver; + RegisterSelectedOps(&resolver); +#else + tflite::ops::builtin::BuiltinOpResolver resolver; +#endif + + tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (!interpreter) { + TFLITE_LOG(FATAL) << "Failed to construct interpreter"; + } + profiling_listener_.SetInterpreter(interpreter.get()); + + if (params_.num_threads != -1) { + interpreter->SetNumThreads(params_.num_threads); + } + + interpreter->UseNNAPI(use_nnapi); + auto interpreter_inputs = interpreter->inputs(); + + if (!inputs.empty()) { + TFLITE_BENCHMARK_CHECK_EQ(inputs.size(), interpreter_inputs.size()) + << "Inputs mismatch: Model inputs #:" << interpreter_inputs.size() + << " expected: " << inputs.size(); + } + + // TFLITE_BENCHMARK_CHECK that all names and types match + for (int j = 0; j < inputs.size(); ++j) { + const InputLayerInfo& input = inputs[j]; + int i = interpreter_inputs[j]; + TfLiteTensor* t = interpreter->tensor(i); + TFLITE_BENCHMARK_CHECK_EQ(t->name, input.name) + << "Tensor # " << i << " is named " << t->name << " but flags call it " + << input.name; + } + + // Resize all non-string tensors. + for (int j = 0; j < inputs.size(); ++j) { + const InputLayerInfo& input = inputs[j]; + int i = interpreter_inputs[j]; + TfLiteTensor* t = interpreter->tensor(i); + if (t->type != kTfLiteString) { + interpreter->ResizeInputTensor(i, input.shape); + } + } + + if (interpreter->AllocateTensors() != kTfLiteOk) { + TFLITE_LOG(FATAL) << "Failed to allocate tensors!"; + } + + // Set the values of the input tensors. 
+ for (int j = 0; j < inputs.size(); ++j) { + const InputLayerInfo& input = inputs[j]; + int i = interpreter_inputs[j]; + TfLiteTensor* t = interpreter->tensor(i); + std::vector<int> sizes = input.shape; + + // TODO(ahentz): below we ignore the 0-th dimension (number of batches). + if (t->type == kTfLiteFloat32) { + FillRandomValue<float>( + interpreter->typed_tensor<float>(i), + std::vector<int>(sizes.begin() + 1, sizes.end()), + []() { return static_cast<float>(rand()) / RAND_MAX - 0.5f; }); + } else if (t->type == kTfLiteUInt8) { + FillRandomValue<uint8_t>( + interpreter->typed_tensor<uint8_t>(i), + std::vector<int>(sizes.begin() + 1, sizes.end()), + []() { return static_cast<uint8_t>(rand()) % 255; }); + } else if (t->type == kTfLiteString) { + tflite::DynamicBuffer buffer; + FillRandomString(&buffer, sizes, []() { + return "we're having some friends over Saturday to hang out in the yard"; + }); + buffer.WriteToTensor(interpreter->tensor(i)); + } else { + TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name + << " of type " << t->type; + } + } +} + +void BenchmarkTfLiteModel::RunImpl() { + if (interpreter->Invoke() != kTfLiteOk) { + TFLITE_LOG(FATAL) << "Failed to invoke!"; + } +} + +} // namespace benchmark +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h new file mode 100644 index 0000000000000000000000000000000000000000..ffb93da964b2da0328616e749abd9c5a84189468 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h @@ -0,0 +1,86 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_TFLITE_MODEL_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_TFLITE_MODEL_H_ + +#include <memory> +#include <string> +#include <vector> + +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/profiling/profile_summarizer.h" +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h" + +namespace tflite { +namespace benchmark { + +// Dumps profiling events if profiling is enabled. +class ProfilingListener : public BenchmarkListener { + public: + explicit ProfilingListener() : interpreter_(nullptr), has_profiles_(false) {} + + void SetInterpreter(Interpreter* interpreter); + + void OnSingleRunStart(RunType run_type) override; + + void OnSingleRunEnd() override; + + void OnBenchmarkEnd(const BenchmarkResults& results) override; + + private: + Interpreter* interpreter_; + profiling::Profiler profiler_; + profiling::ProfileSummarizer summarizer_; + bool has_profiles_; +}; + +// Benchmarks a TFLite model by running the TFLite interpreter.
+class BenchmarkTfLiteModel : public BenchmarkModel { + public: + BenchmarkTfLiteModel() : use_nnapi(false) { + AddListener(&profiling_listener_); + } + + std::vector<Flag> GetFlags() override; + void LogFlags() override; + bool ValidateFlags() override; + uint64_t ComputeInputBytes() override; + void Init() override; + void RunImpl() override; + virtual ~BenchmarkTfLiteModel() {} + + struct InputLayerInfo { + std::string name; + std::vector<int> shape; + }; + + private: + std::unique_ptr<tflite::FlatBufferModel> model; + std::unique_ptr<tflite::Interpreter> interpreter; + std::string graph; + std::string input_layer_string; + std::string input_layer_type_string; + std::string input_layer_shape_string; + std::string input_layer_values_string; + std::vector<InputLayerInfo> inputs; + bool use_nnapi; + ProfilingListener profiling_listener_; +}; + +} // namespace benchmark +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_TFLITE_MODEL_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc new file mode 100644 index 0000000000000000000000000000000000000000..8195fc44beb288eec3c020791b47eefa01536fb7 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc @@ -0,0 +1,194 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" + +#include +#include +#include +#include + +namespace tflite { +namespace { + +template <typename T> +std::string ToString(T val) { + std::ostringstream stream; + stream << val; + return stream.str(); +} + +bool ParseFlag(const std::string& arg, const std::string& flag, + const std::function<bool(const std::string&)>& parse_func, + bool* value_parsing_ok) { + *value_parsing_ok = true; + std::string flag_prefix = "--" + flag + "="; + if (arg.find(flag_prefix) != 0) { + return false; + } + bool has_value = arg.size() >= flag_prefix.size(); + *value_parsing_ok = has_value; + if (has_value) { + *value_parsing_ok = parse_func(arg.substr(flag_prefix.size())); + } + return true; +} + +template <typename T> +bool ParseFlag(const std::string& flag_value, T* value) { + std::istringstream stream(flag_value); + T read_value; + stream >> read_value; + if (!stream.eof() && !stream.good()) { + return false; + } + *value = read_value; + return true; +} + +bool ParseBoolFlag(const std::string& flag_value, bool* value) { + if (flag_value != "true" && flag_value != "false") { + return false; + } + + *value = (flag_value == "true"); + return true; +} + +bool ParseStringFlag(const std::string& flag_value, std::string* value) { + *value = flag_value; + return true; +} + +} // namespace + +Flag::Flag(const char* name, int32_t* dst, const std::string& usage_text) + : name_(name), + type_(TYPE_INT32), + value_hook_([dst](const std::string& flag_value) { + return ParseFlag(flag_value, dst); + }), + default_for_display_(ToString(*dst)), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, int64_t* dst, const std::string& usage_text) + : name_(name), + type_(TYPE_INT64), + value_hook_([dst](const std::string& flag_value) { + return ParseFlag(flag_value, dst); + }), + default_for_display_(ToString(*dst)), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, float* dst,
const std::string& usage_text) + : name_(name), + type_(TYPE_FLOAT), + value_hook_([dst](const std::string& flag_value) { + return ParseFlag(flag_value, dst); + }), + default_for_display_(ToString(*dst)), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, bool* dst, const std::string& usage_text) + : name_(name), + type_(TYPE_BOOL), + value_hook_([dst](const std::string& flag_value) { + return ParseBoolFlag(flag_value, dst); + }), + default_for_display_((*dst) ? "true" : "false"), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, std::string* dst, const std::string& usage_text) + : name_(name), + type_(TYPE_STRING), + value_hook_([dst](const std::string& flag_value) { + return ParseStringFlag(flag_value, dst); + }), + default_for_display_(*dst), + usage_text_(usage_text) {} + +bool Flag::Parse(const std::string& arg, bool* value_parsing_ok) const { + return ParseFlag(arg, name_, value_hook_, value_parsing_ok); +} + +std::string Flag::GetTypeName() const { + switch (type_) { + case TYPE_INT32: + return "int32"; + case TYPE_INT64: + return "int64"; + case TYPE_FLOAT: + return "float"; + case TYPE_BOOL: + return "bool"; + case TYPE_STRING: + return "string"; + } + + return "unknown"; +} + +/*static*/ bool Flags::Parse(int* argc, const char** argv, + const std::vector& flag_list) { + bool result = true; + std::vector unknown_flags; + for (int i = 1; i < *argc; ++i) { + if (std::string(argv[i]) == "--") { + while (i < *argc) { + unknown_flags.push_back(argv[i]); + ++i; + } + break; + } + + bool was_found = false; + for (const Flag& flag : flag_list) { + bool value_parsing_ok; + was_found = flag.Parse(argv[i], &value_parsing_ok); + if (!value_parsing_ok) { + result = false; + } + if (was_found) { + break; + } + } + if (!was_found) { + unknown_flags.push_back(argv[i]); + } + } + int dst = 1; // Skip argv[0] + for (auto f : unknown_flags) { + argv[dst++] = f; + } + argv[dst++] = nullptr; + *argc = unknown_flags.size() + 1; + return result && (*argc < 2 
|| std::strcmp(argv[1], "--help") != 0); +} + +/*static*/ std::string Flags::Usage(const std::string& cmdline, + const std::vector& flag_list) { + std::ostringstream usage_text; + usage_text << "usage: " << cmdline << "\n"; + if (!flag_list.empty()) { + usage_text << "Flags:\n"; + } + + for (const Flag& flag : flag_list) { + auto type_name = flag.GetTypeName(); + usage_text << "\t"; + usage_text << "--" << flag.name_ << "=" << flag.default_for_display_; + usage_text << "\t" << type_name << "\t" << flag.usage_text_ << "\n"; + } + return usage_text.str(); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h new file mode 100644 index 0000000000000000000000000000000000000000..36f9e64767315a317338bc4d2db2ec2d43bee875 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_COMMAND_LINE_FLAGS_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_COMMAND_LINE_FLAGS_H_ + +#include +#include +#include + +namespace tflite { +// A simple command-line argument parsing module. +// Dependency free simplified port of core/util/command_line_flags. +// This class is written for benchmarks and uses inefficient string +// concatenation. 
This was written to avoid dependency on tensorflow/core/util +// which transitively brings in a lot of other dependencies that are not +// necessary for tflite benchmarking code. +// The recommended way of using it is with local variables and an initializer +// list of Flag objects, for example: +// +// int some_int = 10; +// bool some_switch = false; +// std::string some_name = "something"; +// std::vector<Flag> flag_list = { +// Flag("some_int", &some_int, "an integer that affects X"), +// Flag("some_switch", &some_switch, "a bool that affects Y"), +// Flag("some_name", &some_name, "a std::string that affects Z") +// }; +// // Get usage message before Flags::Parse() to capture default values. +// std::string usage = Flags::Usage(argv[0], flag_list); +// bool parsed_values_ok = Flags::Parse(&argc, argv, flag_list); +// +// tensorflow::port::InitMain(usage.c_str(), &argc, &argv); +// if (argc != 1 || !parsed_values_ok) { +// ...output usage and error message... +// } +// +// The argc and argv values are adjusted by the Parse function so all that +// remains is the program name (at argv[0]) and any unknown arguments fill the +// rest of the array. This means you can check for flags that weren't understood +// by seeing if argc is greater than 1. +// The result indicates if there were any errors parsing the values that were +// passed to the command-line switches. For example, --some_int=foo would return +// false because the argument is expected to be an integer. +// +// NOTE: Unlike gflags-style libraries, this library is intended to be +// used in the `main()` function of your binary. It does not handle +// flag definitions that are scattered around the source code. + +// A description of a single command line flag, holding its name, type, usage +// text, and a pointer to the corresponding variable. 
+class Flag { + public: + Flag(const char* name, int32_t* dst, const std::string& usage_text); + Flag(const char* name, int64_t* dst, const std::string& usage_text); + Flag(const char* name, bool* dst, const std::string& usage_text); + Flag(const char* name, std::string* dst, const std::string& usage_text); + Flag(const char* name, float* dst, const std::string& usage_text); + + private: + friend class Flags; + + bool Parse(const std::string& arg, bool* value_parsing_ok) const; + + std::string name_; + enum { + TYPE_INT32, + TYPE_INT64, + TYPE_BOOL, + TYPE_STRING, + TYPE_FLOAT, + } type_; + + std::string GetTypeName() const; + + std::function value_hook_; + std::string default_for_display_; + + std::string usage_text_; +}; + +class Flags { + public: + // Parse the command line represented by argv[0, ..., (*argc)-1] to find flag + // instances matching flags in flaglist[]. Update the variables associated + // with matching flags, and remove the matching arguments from (*argc, argv). + // Return true iff all recognized flag values were parsed correctly, and the + // first remaining argument is not "--help". + static bool Parse(int* argc, const char** argv, + const std::vector& flag_list); + + // Return a usage message with command line cmdline, and the + // usage_text strings in flag_list[]. + static std::string Usage(const std::string& cmdline, + const std::vector& flag_list); +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_COMMAND_LINE_FLAGS_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..620d61b027d30044ba9d449a8e308375f72ad76f --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc @@ -0,0 +1,166 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" +#include +#include +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { +namespace { + +TEST(CommandLineFlagsTest, BasicUsage) { + int some_int32 = 10; + int64_t some_int64 = 21474836470; // max int32 is 2147483647 + bool some_switch = false; + std::string some_name = "something_a"; + float some_float = -23.23f; + const char* argv_strings[] = {"program_name", + "--some_int32=20", + "--some_int64=214748364700", + "--some_switch=true", + "--some_name=somethingelse", + "--some_float=42.0"}; + int argc = 6; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + { + Flag("some_int32", &some_int32, "some int32"), + Flag("some_int64", &some_int64, "some int64"), + Flag("some_switch", &some_switch, "some switch"), + Flag("some_name", &some_name, "some name"), + Flag("some_float", &some_float, "some float"), + }); + + EXPECT_EQ(true, parsed_ok); + EXPECT_EQ(20, some_int32); + EXPECT_EQ(214748364700, some_int64); + EXPECT_EQ(true, some_switch); + EXPECT_EQ("somethingelse", some_name); + EXPECT_NEAR(42.0f, some_float, 1e-5f); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, EmptyStringFlag) { + int argc = 2; + std::string some_string = "invalid"; + const char* argv_strings[] = {"program_name", "--some_string="}; + bool parsed_ok = + Flags::Parse(&argc, 
reinterpret_cast(argv_strings), + {Flag("some_string", &some_string, "some string")}); + + EXPECT_EQ(true, parsed_ok); + EXPECT_EQ(some_string, ""); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, BadIntValue) { + int some_int = 10; + int argc = 2; + const char* argv_strings[] = {"program_name", "--some_int=notanumber"}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag("some_int", &some_int, "some int")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(10, some_int); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, BadBoolValue) { + bool some_switch = false; + int argc = 2; + const char* argv_strings[] = {"program_name", "--some_switch=notabool"}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag("some_switch", &some_switch, "some switch")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(false, some_switch); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, BadFloatValue) { + float some_float = -23.23f; + int argc = 2; + const char* argv_strings[] = {"program_name", "--some_float=notanumber"}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag("some_float", &some_float, "some float")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_NEAR(-23.23f, some_float, 1e-5f); + EXPECT_EQ(argc, 1); +} + +// Return whether str==pat, but allowing any whitespace in pat +// to match zero or more whitespace characters in str. 
+static bool MatchWithAnyWhitespace(const std::string& str, + const std::string& pat) { + bool matching = true; + int pat_i = 0; + for (int str_i = 0; str_i != str.size() && matching; str_i++) { + if (isspace(str[str_i])) { + matching = (pat_i != pat.size() && isspace(pat[pat_i])); + } else { + while (pat_i != pat.size() && isspace(pat[pat_i])) { + pat_i++; + } + matching = (pat_i != pat.size() && str[str_i] == pat[pat_i++]); + } + } + while (pat_i != pat.size() && isspace(pat[pat_i])) { + pat_i++; + } + return (matching && pat_i == pat.size()); +} + +TEST(CommandLineFlagsTest, UsageString) { + int some_int = 10; + int64_t some_int64 = 21474836470; // max int32 is 2147483647 + bool some_switch = false; + std::string some_name = "something"; + // Don't test float in this case, because precision is hard to predict and + // match against, and we don't want a flakey test. + const std::string tool_name = "some_tool_name"; + std::string usage = Flags::Usage( + tool_name + " ", {Flag("some_int", &some_int, "some int"), + Flag("some_int64", &some_int64, "some int64"), + Flag("some_switch", &some_switch, "some switch"), + Flag("some_name", &some_name, "some name")}); + // Match the usage message, being sloppy about whitespace. + const char* expected_usage = + " usage: some_tool_name \n" + "Flags:\n" + "--some_int=10\tint32\tsome int\n" + "--some_int64=21474836470\tint64\tsome int64\n" + "--some_switch=false\tbool\tsome switch\n" + "--some_name=something\tstring\tsome name\n"; + ASSERT_EQ(MatchWithAnyWhitespace(usage, expected_usage), true) << usage; + + // Again but with no flags. 
+ usage = Flags::Usage(tool_name, {}); + ASSERT_EQ(MatchWithAnyWhitespace(usage, " usage: some_tool_name\n"), true) + << usage; +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/tools/benchmark/logging.h b/tensorflow/contrib/lite/tools/benchmark/logging.h new file mode 100644 index 0000000000000000000000000000000000000000..9e9292e2feacf0eff0751534f02cdacd21c9b0dd --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/logging.h @@ -0,0 +1,76 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_LOGGING_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_LOGGING_H_ + +// LOG and CHECK macros for benchmarks. + +#include +#include +#include + +namespace tflite { +namespace logging { +// A wrapper that logs to stderr. +// +// Used for TFLITE_LOG and TFLITE_BENCHMARK_CHECK macros. 
+class LoggingWrapper { + public: + enum class LogSeverity : int { + INFO = 0, + WARN = 1, + ERROR = 2, + FATAL = 3, + }; + LoggingWrapper(LogSeverity severity) + : severity_(severity), should_log_(true) {} + LoggingWrapper(LogSeverity severity, bool log) + : severity_(severity), should_log_(log) {} + std::stringstream& Stream() { return stream_; } + ~LoggingWrapper() { + if (should_log_) { + std::cerr << stream_.str() << std::endl; + if (severity_ == LogSeverity::FATAL) { + std::flush(std::cerr); + std::abort(); + } + } + } + + private: + std::stringstream stream_; + LogSeverity severity_; + bool should_log_; +}; + +} // namespace logging + +} // namespace tflite + +#define TFLITE_LOG(severity) \ + tflite::logging::LoggingWrapper( \ + tflite::logging::LoggingWrapper::LogSeverity::severity) \ + .Stream() + +#define TFLITE_BENCHMARK_CHECK(condition) \ + tflite::logging::LoggingWrapper( \ + tflite::logging::LoggingWrapper::LogSeverity::FATAL, \ + (condition) ? false : true) \ + .Stream() + +#define TFLITE_BENCHMARK_CHECK_EQ(a, b) TFLITE_BENCHMARK_CHECK(a == b) + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_LOGGING_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark_model.cc deleted file mode 100644 index 93c80e0f5e021f76bff6858b0ea3370724393d6d..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/tools/benchmark_model.cc +++ /dev/null @@ -1,475 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include - -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/string_util.h" -#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/util/command_line_flags.h" - -#ifdef TFLITE_CUSTOM_OPS_HEADER -void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); -#endif - -namespace tflite { - -using ::tensorflow::Env; -using ::tensorflow::str_util::Split; -using ::tensorflow::str_util::SplitAndParseAsFloats; -using ::tensorflow::str_util::SplitAndParseAsInts; - -struct InputLayerInfo { - string name; - TfLiteType data_type; - std::vector shape; - // Note that initialization_values is currently unused. 
- std::vector initialization_values; -}; - -template -void FillRandomValue(T* ptr, const std::vector& sizes, - const std::function& random_func) { - int num_elements = 1; - for (int dim : sizes) { - num_elements *= dim; - } - for (int i = 0; i < num_elements; ++i) { - *ptr++ = random_func(); - } -} - -void FillRandomString(tflite::DynamicBuffer* buffer, - const std::vector& sizes, - const std::function& random_func) { - int num_elements = 1; - for (int dim : sizes) { - num_elements *= dim; - } - for (int i = 0; i < num_elements; ++i) { - auto str = random_func(); - buffer->AddString(str.data(), str.length()); - } -} - -TfLiteType TfLiteTypeFromString(const string& input_layer_type) { - if (input_layer_type == "string") - return kTfLiteString; - else if (input_layer_type == "float") - return kTfLiteFloat32; - else if (input_layer_type == "uint8") - return kTfLiteUInt8; - else if (input_layer_type == "int32") - return kTfLiteInt32; - else if (input_layer_type == "int64") - return kTfLiteInt64; - else - return kTfLiteNoType; -} - -std::vector ShapeFromTfLiteTensor(TfLiteTensor* t) { - std::vector result; - result.reserve(t->dims->size); - for (int i = 0; i < t->dims->size; ++i) { - result.push_back(t->dims->data[i]); - } - CHECK(!result.empty()) << "Found no shapes in model"; - return result; -} - -bool CreateInterpreter(const string& graph, - std::unique_ptr* model, - std::unique_ptr* interpreter) { - *model = tflite::FlatBufferModel::BuildFromFile(graph.c_str()); - if (!model) { - std::cerr << "Failed to load model " << graph << std::endl; - return false; - } - -#ifdef TFLITE_CUSTOM_OPS_HEADER - tflite::MutableOpResolver resolver; - RegisterSelectedOps(&resolver); -#else - tflite::ops::builtin::BuiltinOpResolver resolver; -#endif - - tflite::InterpreterBuilder(*(model->get()), resolver)(interpreter); - if (!(*interpreter)) { - std::cerr << "Failed to construct interpreter" << std::endl; - return false; - } - - return true; -} - -bool PrepareInterpreter(const 
std::vector inputs, - int num_threads, bool use_nnapi, - Interpreter* interpreter) { - if (num_threads != -1) { - interpreter->SetNumThreads(num_threads); - } - - interpreter->UseNNAPI(use_nnapi); - - // Check that all names and types match - for (const InputLayerInfo& input : inputs) { - for (int i : interpreter->inputs()) { - TfLiteTensor* t = interpreter->tensor(i); - CHECK_EQ(t->name, input.name) - << "Tensor # " << i << " is named " << t->name - << " but flags call it " << input.name; - CHECK_EQ(t->type, input.data_type) - << "Could not match the type of input tensor " << t->name; - } - } - - // Resize all non-string tensors. - for (const InputLayerInfo& input : inputs) { - for (int i : interpreter->inputs()) { - TfLiteTensor* t = interpreter->tensor(i); - if (t->type != kTfLiteString) { - interpreter->ResizeInputTensor(i, input.shape); - } - } - } - - if (interpreter->AllocateTensors() != kTfLiteOk) { - std::cerr << "Failed to allocate tensors!" << std::endl; - return false; - } - - // Set the values of the input tensors. - for (int i : interpreter->inputs()) { - TfLiteTensor* t = interpreter->tensor(i); - std::vector sizes = ShapeFromTfLiteTensor(t); - - // TODO(ahentz): below we ignore the O-th dimension (number of batches). 
- if (t->type == kTfLiteFloat32) { - FillRandomValue( - interpreter->typed_tensor(i), - std::vector(sizes.begin() + 1, sizes.end()), - []() { return static_cast(rand()) / RAND_MAX - 0.5f; }); - } else if (t->type == kTfLiteUInt8) { - FillRandomValue( - interpreter->typed_tensor(i), - std::vector(sizes.begin() + 1, sizes.end()), - []() { return static_cast(rand()) % 255; }); - } else if (t->type == kTfLiteString) { - tflite::DynamicBuffer buffer; - FillRandomString(&buffer, sizes, []() { - return "we're have some friends over saturday to hang out in the yard"; - }); - buffer.WriteToTensor(interpreter->tensor(i)); - } else { - std::cerr << "Don't know how to populate tensor " << t->name - << " of type " << t->type << std::endl; - return false; - } - } - return true; -} - -bool PopulateInputLayerInfo(const string& names_string, - const string& shapes_string, - const string& types_string, - const string& values_string, - std::vector* info) { - std::vector names = Split(names_string, ','); - std::vector shapes = Split(shapes_string, ':'); - std::vector types = Split(types_string, ','); - std::vector values = Split(values_string, ':'); - - if (names.size() != shapes.size()) { - LOG(ERROR) << "The number of items in" - << " --input_layer_shape (" << shapes_string << ", with " - << shapes.size() << " items)" - << " must match the number of items in" - << " --input_layer (" << names_string << ", with " - << names.size() << " items)." - << " For example --input_layer=input1,input2" - << " --input_layer_shape=1,224,224,4:1,20"; - return false; - } - if (names.size() != types.size()) { - LOG(ERROR) << "The number of items in" - << " --input_layer_type (" << types_string << ", with " - << types.size() << " items)" - << " must match the number of items in" - << " --input_layer (" << names_string << ", with " - << names.size() << " items)." 
- << " For example --input_layer=input1,input2" - << " --input_layer_type=float,int"; - return false; - } - - for (int i = 0; i < names.size(); ++i) { - info->push_back(InputLayerInfo()); - InputLayerInfo& input = info->back(); - - input.name = names[i]; - - input.data_type = TfLiteTypeFromString(types[i]); - CHECK(input.data_type != kTfLiteNoType) - << types[i] << " was an invalid type"; - - CHECK(SplitAndParseAsInts(shapes[i], ',', &input.shape)) - << "Incorrect size string specified: " << shapes[i]; - for (int dim : input.shape) { - if (dim == -1) { - LOG(ERROR) << "Any unknown sizes in the shapes (-1's) must be replaced" - << " with the size you want to benchmark with."; - return false; - } - } - - if (i < values.size()) { - CHECK(SplitAndParseAsFloats(values[i], ',', &input.initialization_values)) - << "Incorrect initialization values string specified: " << values[i]; - } - } - - return true; -} - -bool RunBenchmark(Interpreter* interpreter, int64_t* inference_time_us) { - const int64_t start_time = Env::Default()->NowMicros(); - - if (interpreter->Invoke() != kTfLiteOk) { - std::cerr << "Failed to invoke!"; - return false; - } - - const int64_t end_time = Env::Default()->NowMicros(); - *inference_time_us = end_time - start_time; - return true; -} - -class Latencies { - public: - void AddMeasurement(int64_t time_us) { - max_ = std::max(time_us, max_); - min_ = std::min(time_us, min_); - ++count_; - sum_ += time_us; - squared_sum_ += static_cast(time_us) * time_us; - } - - double avg() const { - if (count_ == 0) return std::numeric_limits::quiet_NaN(); - return static_cast(sum_) / count_; - } - - int64_t std_deviation() const { - if (count_ == 0 || min_ == max_) return 0; - return sqrt(squared_sum_ / count_ - avg() * avg()); - } - - void OutputToStream(std::ostream* stream) const { - *stream << "count=" << count_; - if (count_ == 0) return; - *stream << " min=" << min_ << " max=" << max_; - *stream << " avg=" << avg() << " std=" << std_deviation(); - } - - 
private: - int64_t count_ = 0; - int64_t min_ = std::numeric_limits::max(); - int64_t max_ = std::numeric_limits::min(); - int64_t sum_ = 0; - double squared_sum_ = 0; -}; - -bool TimeMultipleRuns(Interpreter* interpreter, double sleep_seconds, - int num_runs, int64* total_time_us) { - // Convert the run_delay string into a timespec. - timespec req; - req.tv_sec = static_cast(sleep_seconds); - req.tv_nsec = (sleep_seconds - req.tv_sec) * 1000000000; - - *total_time_us = 0; - - std::cout << "Running benchmark for " << num_runs - << " iterations: " << std::endl; - - Latencies latencies; - for (int i = 0; i < num_runs; ++i) { - int64_t time_us; - bool run_status = RunBenchmark(interpreter, &time_us); - latencies.AddMeasurement(time_us); - *total_time_us += time_us; - if (!run_status) { - std::cout << "Failed on run " << i << std::endl; - return false; - } - - // If requested, sleep between runs for an arbitrary amount of time. - // This can be helpful to determine the effect of mobile processor - // scaling and thermal throttling. 
- if (sleep_seconds > 0.0) { -#ifdef PLATFORM_WINDOWS - Sleep(sleep_seconds * 1000); -#else - nanosleep(&req, nullptr); -#endif - } - } - latencies.OutputToStream(&std::cout); - std::cout << std::endl; - - return true; -} - -int Main(int argc, char** argv) { - using tensorflow::Flag; - using tensorflow::Flags; - - string graph; // e.g.: /data/local/tmp/tfl_inception-v1_model.fb - string input_layer_string; // e.g.: input - string input_layer_shape_string; // e.g.: 1,224,224,3 - string input_layer_type_string; // e.g.: float - string input_layer_values_string; - string output_layer_string; // e.g.: output - int num_runs = 50; - string run_delay = "-1.0"; - int num_threads = -1; - string benchmark_name = ""; - string output_prefix = ""; - int warmup_runs = 1; - bool use_nnapi = false; - - std::vector flag_list = { - Flag("graph", &graph, "graph file name"), - // All the following flags are optional, but can be used in order - // to benchmark different input shapes. - Flag("input_layer", &input_layer_string, "input layer names"), - Flag("input_layer_shape", &input_layer_shape_string, "input layer shape"), - Flag("input_layer_type", &input_layer_type_string, "input layer type"), - Flag("input_layer_values", &input_layer_values_string, - "values to initialize the inputs with"), - Flag("output_layer", &output_layer_string, "output layer name"), - Flag("num_runs", &num_runs, "number of runs"), - Flag("run_delay", &run_delay, "delay between runs in seconds"), - Flag("num_threads", &num_threads, "number of threads"), - Flag("benchmark_name", &benchmark_name, "benchmark name"), - Flag("output_prefix", &output_prefix, "benchmark output prefix"), - Flag("warmup_runs", &warmup_runs, "how many runs to initialize model"), - Flag("use_nnapi", &use_nnapi, "use nnapi api"), - }; - string usage = Flags::Usage(argv[0], flag_list); - const bool parse_result = Flags::Parse(&argc, argv, flag_list); - tensorflow::port::InitMain(argv[0], &argc, &argv); - - if (!parse_result) { - std::cerr 
<< usage << std::endl; - return -1; - } - - std::cout << "Graph: [" << graph << "]" << std::endl; - if (!input_layer_string.empty()) { - std::cout << "Input layers: [" << input_layer_string << "]" << std::endl; - std::cout << "Input shapes: [" << input_layer_shape_string << "]" - << std::endl; - std::cout << "Input types: [" << input_layer_type_string << "]" - << std::endl; - } - if (!output_layer_string.empty()) { - std::cout << "Output layers: [" << output_layer_string << "]" << std::endl; - } - std::cout << "Num runs: [" << num_runs << "]" << std::endl; - std::cout << "Inter-run delay (seconds): [" << run_delay << "]" << std::endl; - std::cout << "Num threads: [" << num_threads << "]" << std::endl; - if (!benchmark_name.empty()) { - std::cout << "Benchmark name: [" << benchmark_name << "]" << std::endl; - std::cout << "Output prefix: [" << output_prefix << "]" << std::endl; - } - std::cout << "Warmup runs: [" << warmup_runs << "]" << std::endl; - std::cout << "Use nnapi : [" << use_nnapi << "]" << std::endl; - - if (graph.empty()) { - std::cout - << "Please specify the name of your TF Lite input file with --graph" - << std::endl; - return -1; - } - - std::vector inputs; - if (!PopulateInputLayerInfo(input_layer_string, input_layer_shape_string, - input_layer_type_string, - input_layer_values_string, &inputs)) { - return -1; - } - - int64 initialization_start_us = Env::Default()->NowMicros(); - - std::unique_ptr model; - std::unique_ptr interpreter; - if (!CreateInterpreter(graph, &model, &interpreter)) { - return -1; - } - if (!PrepareInterpreter(inputs, num_threads, use_nnapi, interpreter.get())) { - return -1; - } - - int64 initialization_end_us = Env::Default()->NowMicros(); - - const double initialization_time_s = - (initialization_end_us - initialization_start_us) / 1000000.0f; - std::cout << "Initialized session in " << initialization_time_s << "s" - << std::endl; - - const double sleep_seconds = std::strtod(run_delay.c_str(), nullptr); - - // If 
requested, run through the graph first to preinitialize everything - // before the benchmarking runs. - int64 warmup_time_us = 0; - if (warmup_runs > 0) { - if (!TimeMultipleRuns(interpreter.get(), sleep_seconds, warmup_runs, - &warmup_time_us)) { - std::cerr << "Warmup failed" << std::endl; - return -1; - } - } - - // Capture overall inference time without stat logging overhead. This is the - // timing data that can be compared to other libaries. - int64 no_stat_time_us = 0; - if (!TimeMultipleRuns(interpreter.get(), sleep_seconds, num_runs, - &no_stat_time_us)) { - std::cerr << "Timing failed." << std::endl; - return -1; - } - - std::cout << "Average inference timings in us: " << no_stat_time_us / num_runs - << " , Warmup: " - << (warmup_runs > 0 ? warmup_time_us / warmup_runs : 0) << ", " - << std::endl; - - return 0; -} - -} // namespace tflite - -int main(int argc, char** argv) { return ::tflite::Main(argc, argv); } diff --git a/tensorflow/contrib/lite/tools/gen_op_registration_main.cc b/tensorflow/contrib/lite/tools/gen_op_registration_main.cc index 17b514c9169817479e18eecf5799ea4371f3b051..f7df80821fc383063c6e19148bfb13801368b334 100644 --- a/tensorflow/contrib/lite/tools/gen_op_registration_main.cc +++ b/tensorflow/contrib/lite/tools/gen_op_registration_main.cc @@ -55,7 +55,7 @@ void GenerateFileContent(const std::string& tflite_path, std::ofstream fout(filename); fout << "#include \"" << tflite_path << "/model.h\"\n"; - fout << "#include \"" << tflite_path << "/tools/mutable_op_resolver.h\"\n"; + fout << "#include \"" << tflite_path << "/op_resolver.h\"\n"; fout << "namespace tflite {\n"; fout << "namespace ops {\n"; diff --git a/tensorflow/contrib/lite/tools/mutable_op_resolver.cc b/tensorflow/contrib/lite/tools/mutable_op_resolver.cc deleted file mode 100644 index 8a921d7c5aa20ce3a9dc279d8f0c7c253905b078..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/tools/mutable_op_resolver.cc +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2017 
The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" - -namespace tflite { - -TfLiteRegistration* MutableOpResolver::FindOp( - tflite::BuiltinOperator op) const { - auto it = builtins_.find(op); - return it != builtins_.end() ? it->second : nullptr; -} - -TfLiteRegistration* MutableOpResolver::FindOp(const char* op) const { - auto it = custom_ops_.find(op); - return it != custom_ops_.end() ? it->second : nullptr; -} - -void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op, - TfLiteRegistration* registration) { - registration->builtin_code = op; - builtins_.insert(std::make_pair(op, registration)); -} - -void MutableOpResolver::AddCustom(const char* name, - TfLiteRegistration* registration) { - registration->builtin_code = BuiltinOperator_CUSTOM; - custom_ops_.insert(std::make_pair(std::string(name), registration)); -} - -} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/mutable_op_resolver.h b/tensorflow/contrib/lite/tools/mutable_op_resolver.h deleted file mode 100644 index 573a359c458acb6e4320c5a21cb378cdde720924..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/tools/mutable_op_resolver.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_ -#define TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_ - -#include -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/model.h" - -// Needed to resolve unordered_set hash on older compilers. -namespace std { -template <> -struct hash { - size_t operator()(const tflite::BuiltinOperator& op) const { - return std::hash()(op); - } -}; -} // namespace std - -namespace tflite { - -// An OpResolver that is mutable, also used as the op in gen_op_registration. 
-// A typical usage: -// MutableOpResolver resolver; -// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD()); -// resolver.AddCustom("CustomOp", Register_CUSTOM_OP()); -// InterpreterBuilder(model, resolver)(&interpreter); -class MutableOpResolver : public OpResolver { - public: - MutableOpResolver() {} - TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override; - TfLiteRegistration* FindOp(const char* op) const override; - void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration); - void AddCustom(const char* name, TfLiteRegistration* registration); - - private: - std::map builtins_; - std::map custom_ops_; -}; - -} // namespace tflite - -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_ diff --git a/tensorflow/contrib/lite/tools/verifier.cc b/tensorflow/contrib/lite/tools/verifier.cc index 8818a7dc85d9ffdc1da450fb389d5ed11139bc31..8d3a7a624265ca6f9933f36949fd6fdbb3c39c40 100644 --- a/tensorflow/contrib/lite/tools/verifier.cc +++ b/tensorflow/contrib/lite/tools/verifier.cc @@ -246,15 +246,16 @@ bool VerifyOps(const Model& model, const OpResolver& resolver, } if (opcode->builtin_code() == BuiltinOperator_CUSTOM) { - if (!resolver.FindOp(opcode->custom_code()->c_str())) { - ReportError(error_reporter, "Unsupported custom op: %s", - opcode->custom_code()->c_str()); + if (!resolver.FindOp(opcode->custom_code()->c_str(), opcode->version())) { + ReportError(error_reporter, "Unsupported custom op: %s, version: %d", + opcode->custom_code()->c_str(), opcode->version()); return false; } } else { - if (!resolver.FindOp(opcode->builtin_code())) { - ReportError(error_reporter, "Unsupported builtin op: %s", - EnumNameBuiltinOperator(opcode->builtin_code())); + if (!resolver.FindOp(opcode->builtin_code(), opcode->version())) { + ReportError(error_reporter, "Unsupported builtin op: %s, version: %d", + EnumNameBuiltinOperator(opcode->builtin_code()), + opcode->version()); return false; } } diff --git 
a/tensorflow/contrib/lite/tools/verifier.h b/tensorflow/contrib/lite/tools/verifier.h index b7ce4e830576af14002d6bd9080af1da5764b1c9..a596c650a0c2533b6ece3cc7c692d863c2d3f860 100644 --- a/tensorflow/contrib/lite/tools/verifier.h +++ b/tensorflow/contrib/lite/tools/verifier.h @@ -26,12 +26,13 @@ namespace tflite { class AlwaysTrueResolver : public OpResolver { public: AlwaysTrueResolver() {} - TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override { + const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, + int version) const override { static TfLiteRegistration null_registration = {nullptr, nullptr, nullptr, nullptr}; return &null_registration; } - TfLiteRegistration* FindOp(const char* op) const override { + const TfLiteRegistration* FindOp(const char* op, int version) const override { static TfLiteRegistration null_registration = {nullptr, nullptr, nullptr, nullptr}; return &null_registration; diff --git a/tensorflow/contrib/lite/tools/verifier_test.cc b/tensorflow/contrib/lite/tools/verifier_test.cc index 03b93afe3ed04b4bff13bc01d7c7c8e9fae9bdf3..ad7d59ecb41a0c81a6a4d8edae5fa6b4b5a7bede 100644 --- a/tensorflow/contrib/lite/tools/verifier_test.cc +++ b/tensorflow/contrib/lite/tools/verifier_test.cc @@ -20,9 +20,9 @@ limitations under the License. #include #include "tensorflow/contrib/lite/allocation.h" #include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/op_resolver.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" #include "tensorflow/contrib/lite/testing/util.h" -#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" #include "tensorflow/contrib/lite/tools/verifier.h" #include "tensorflow/contrib/lite/version.h" #include "tensorflow/core/framework/numeric_types.h" @@ -31,7 +31,6 @@ namespace tflite { using flatbuffers::FlatBufferBuilder; using flatbuffers::Offset; -using flatbuffers::Vector; // Build single subgraph model. 
class TfLiteFlatbufferModelBuilder { @@ -42,7 +41,7 @@ class TfLiteFlatbufferModelBuilder { } TfLiteFlatbufferModelBuilder(const std::vector& builtin_ops, - const std::vector& custom_ops) { + const std::vector& custom_ops) { buffers_.push_back( CreateBuffer(builder_, builder_.CreateVector(std::vector{}))); @@ -195,8 +194,8 @@ TEST(VerifyModel, TensorBufferIsNotValid) { /*operators=*/0, builder.CreateString("Main"))}); auto buffers = builder.CreateVector(std::vector>{ - CreateBuffer(builder, - builder.CreateVector(std::vector{1, 2, 3, 4, 5, 6})), + CreateBuffer(builder, builder.CreateVector( + std::vector{1, 2, 3, 4, 5, 6})), }); auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, /*operator_codes=*/0, diff --git a/tensorflow/contrib/lite/util.cc b/tensorflow/contrib/lite/util.cc index fb4af07d060cac3a6a4e01c7d625b6db5241f10d..8ccb65c24fd64f05d7e2c888f7932e586c1e11ec 100644 --- a/tensorflow/contrib/lite/util.cc +++ b/tensorflow/contrib/lite/util.cc @@ -38,4 +38,14 @@ bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size, return true; } +size_t CombineHashes(std::initializer_list hashes) { + size_t result = 0; + // Hash combiner used by TensorFlow core. 
+ for (size_t hash : hashes) { + result = result ^ + (hash + 0x9e3779b97f4a7800ULL + (result << 10) + (result >> 4)); + } + return result; +} + } // namespace tflite diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h index a34db35823104414cce028b9119397da085d05b1..89d9b4f5cffa99e708f391fd8fe19208009b5e79 100644 --- a/tensorflow/contrib/lite/util.h +++ b/tensorflow/contrib/lite/util.h @@ -35,6 +35,8 @@ TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int rank, const int* dims); bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size, const int* b); +size_t CombineHashes(std::initializer_list hashes); + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_UTIL_H_ diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 5d4682ec9f4b8c5864383bd1d2f4c0b41a11baad..5a080cceabb55c307dcd1a457a9e30d24e0bd172 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -24,6 +24,7 @@ import six from tensorflow.contrib import lookup from tensorflow.python.client import session +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -1396,15 +1397,22 @@ class KeyValueTensorInitializerTest(test.TestCase): class IndexTableFromTensor(test.TestCase): + @test_util.run_in_graph_and_eager_modes() def test_index_table_from_tensor_with_tensor_init(self): - with self.test_session(): + table = lookup.index_table_from_tensor( + mapping=("brain", "salad", "surgery"), num_oov_buckets=1) + + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(table.lookup( + constant_op.constant(("salad", "surgery", "tarkus")))) + else: + # Reinitializing a table in eager should work. 
table = lookup.index_table_from_tensor( mapping=("brain", "salad", "surgery"), num_oov_buckets=1) - ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus"))) - - self.assertRaises(errors_impl.OpError, ids.eval) - lookup_ops.tables_initializer().run() - self.assertAllEqual((1, 2, 3), ids.eval()) + self.evaluate(lookup_ops.tables_initializer()) + ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus"))) + self.assertAllEqual((1, 2, 3), self.evaluate(ids)) def test_int32_index_table_from_tensor_with_tensor_init(self): with self.test_session(): diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index bdad34a665e47a4e060fcaddfffecfdc876a8fb0..651de4e2f446b2da39b000cde2541872116cbdba 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -482,9 +482,12 @@ def hinge_loss(logits, labels=None, scope=None): """Method that returns the loss tensor for hinge loss. Args: - logits: The logits, a float tensor. + logits: The logits, a float tensor. Note that logits are assumed to be + unbounded and 0-centered. A value > 0 (resp. < 0) is considered a positive + (resp. negative) binary prediction. labels: The ground truth output tensor. Its shape should match the shape of - logits. The values of the tensor are expected to be 0.0 or 1.0. + logits. The values of the tensor are expected to be 0.0 or 1.0. Internally + the {0,1} labels are converted to {-1,1} when calculating the hinge loss. scope: The scope for the operations performed in computing the loss. 
Returns: diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py index 1417772e0496cb571488e5b30bd4f3fb1b591730..2a442a8fc85c8ab70dfa3b2183fc50f5c9a468e4 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py @@ -24,10 +24,8 @@ from tensorflow.contrib.framework.python.ops import arg_scope from tensorflow.contrib.losses.python.losses import loss_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -275,7 +273,6 @@ class SoftmaxCrossEntropyLossTest(test.TestCase): self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3) -@test_util.with_c_api class SparseSoftmaxCrossEntropyLossTest(test.TestCase): def testNoneWeightRaisesValueError(self): @@ -473,11 +470,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): labels = constant_op.constant([[0, 1], [2, 3]]) weights = constant_op.constant([1.2, 3.4, 5.6, 7.8]) - if ops._USE_C_API: - error_type = ValueError - else: - error_type = errors_impl.InvalidArgumentError - with self.assertRaises(error_type): + with self.assertRaises(ValueError): loss_ops.sparse_softmax_cross_entropy( logits, labels, weights=weights).eval() diff --git a/tensorflow/contrib/makefile/compile_nsync.sh b/tensorflow/contrib/makefile/compile_nsync.sh index e8c6edd7ba9aa6a45d956d1d5655b2809d8d2309..a28fc3a87f9503074806d780a11878a9274efc6f 100755 --- a/tensorflow/contrib/makefile/compile_nsync.sh +++ b/tensorflow/contrib/makefile/compile_nsync.sh @@ -270,7 +270,7 @@ for arch in $archs; do PLATFORM_LDFLAGS=-pthread 
MKDEP=${CC} -M -std=c++11 PLATFORM_C=../../platform/c++11/src/nsync_semaphore_mutex.cc \ - ../../platform/c++11/src/per_thread_waiter.cc \ + ../../platform/posix/src/per_thread_waiter.c \ ../../platform/c++11/src/yield.cc \ ../../platform/c++11/src/time_rep_timespec.cc \ ../../platform/c++11/src/nsync_panic.cc diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index eff9081e35c285027c764c5bdbaf14f78bc5f512..48953e2e3843ff92744514d28bd725cc0d72f3a8 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -27,9 +27,7 @@ if [ ! -f $BZL_FILE_PATH ]; then fi EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" -# TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' once -# the archive has been propagated in mirror.bazel.build. -GEMMLOWP_URL="$(grep -o 'https://github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" +GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" PROTOBUF_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index d4c3f2eda8be0c70e961afe582983b9f73769c77..89db9ee2794ddf0a99951dca327e74c5d9694d23 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -300,7 +300,6 @@ tensorflow/core/kernels/spacetobatch_op.cc tensorflow/core/kernels/batchtospace_op.cc tensorflow/core/kernels/warn_about_ints.cc 
tensorflow/core/kernels/segment_reduction_ops.cc -tensorflow/core/kernels/batch_util.cc tensorflow/core/ops/audio_ops.cc tensorflow/core/kernels/decode_proto_op.cc tensorflow/core/kernels/encode_proto_op.cc diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD index e050f3c8d4fc61adfdd3d15869e533ed2b51c4a8..4f2c82ca23011667662c74507fcbd99bcde4c7c0 100644 --- a/tensorflow/contrib/metrics/BUILD +++ b/tensorflow/contrib/metrics/BUILD @@ -77,7 +77,7 @@ py_test( py_test( name = "metric_ops_test", srcs = ["python/ops/metric_ops_test.py"], - shard_count = 8, + shard_count = 16, srcs_version = "PY2AND3", tags = ["noasan"], # times out b/63678675 deps = [ diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 00a933e5e0c537033573b225d43581f74557b240..a6be2084aae6bb05f958929b45977ed21b570603 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -1544,7 +1544,7 @@ def precision_recall_at_equal_thresholds(labels, result: A named tuple (See PrecisionRecallData within the implementation of this function) with properties that are variables of shape `[num_thresholds]`. The names of the properties are tp, fp, tn, fn, - precision, recall, thresholds. + precision, recall, thresholds. Types are same as that of predictions. update_op: An op that accumulates values. Raises: @@ -1570,7 +1570,6 @@ def precision_recall_at_equal_thresholds(labels, check_ops.assert_type(labels, dtypes.bool) - dtype = predictions.dtype with variable_scope.variable_scope(name, 'precision_recall_at_equal_thresholds', (labels, predictions, weights)): @@ -1592,11 +1591,16 @@ def precision_recall_at_equal_thresholds(labels, predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - # We cast to float to ensure we have 0.0 or 1.0. 
- f_labels = math_ops.cast(labels, dtype) + # It's important we aggregate using float64 since we're accumulating a lot + # of 1.0's for the true/false labels, and accumulating to float32 will + # be quite inaccurate even with just a modest amount of values (~20M). + # We use float64 instead of integer primarily since GPU scatter kernel + # only support floats. + agg_dtype = dtypes.float64 - # Get weighted true/false labels. - true_labels = f_labels * weights + f_labels = math_ops.cast(labels, agg_dtype) + weights = math_ops.cast(weights, agg_dtype) + true_labels = f_labels * weights false_labels = (1.0 - f_labels) * weights # Flatten predictions and labels. @@ -1638,9 +1642,9 @@ def precision_recall_at_equal_thresholds(labels, with ops.name_scope('variables'): tp_buckets_v = metrics_impl.metric_variable( - [num_thresholds], dtype, name='tp_buckets') + [num_thresholds], agg_dtype, name='tp_buckets') fp_buckets_v = metrics_impl.metric_variable( - [num_thresholds], dtype, name='fp_buckets') + [num_thresholds], agg_dtype, name='fp_buckets') with ops.name_scope('update_op'): update_tp = state_ops.scatter_add( @@ -1660,18 +1664,21 @@ def precision_recall_at_equal_thresholds(labels, fn = tp[0] - tp # We use a minimum to prevent division by 0. - epsilon = 1e-7 + epsilon = ops.convert_to_tensor(1e-7, dtype=agg_dtype) precision = tp / math_ops.maximum(epsilon, tp + fp) recall = tp / math_ops.maximum(epsilon, tp + fn) + # Convert all tensors back to predictions' dtype (as per function contract). 
+ out_dtype = predictions.dtype + _convert = lambda tensor: math_ops.cast(tensor, out_dtype) result = PrecisionRecallData( - tp=tp, - fp=fp, - tn=tn, - fn=fn, - precision=precision, - recall=recall, - thresholds=math_ops.lin_space(0.0, 1.0, num_thresholds)) + tp=_convert(tp), + fp=_convert(fp), + tn=_convert(tn), + fn=_convert(fn), + precision=_convert(precision), + recall=_convert(recall), + thresholds=_convert(math_ops.lin_space(0.0, 1.0, num_thresholds))) update_op = control_flow_ops.group(update_tp, update_fp) return result, update_op @@ -2496,7 +2503,7 @@ def _compute_recall_at_precision(tp, fp, fn, precision, name): name: An optional variable_scope name. Returns: - The recall at a the given `precision`. + The recall at a given `precision`. """ precisions = math_ops.div(tp, tp + fp + _EPSILON) tf_index = math_ops.argmin( diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 76420db8bda39435bcc2be2fd3d8c3467d6753e2..b13f08a37d9e856d56903324fc6e7cf1457bb191 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -2333,47 +2333,24 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): np.random.seed(1) ops.reset_default_graph() - def _testResultsEqual(self, expected_dict, gotten_result): + def _testResultsEqual(self, expected_dict, gotten_result, eps=None): """Tests that 2 results (dicts) represent the same data. Args: expected_dict: A dictionary with keys that are the names of properties of PrecisionRecallData and whose values are lists of floats. gotten_result: A PrecisionRecallData object. + eps: Epsilon value to use for testing output values. If unspecified, use + default from assertAllClose. 
""" gotten_dict = {k: t.eval() for k, t in gotten_result._asdict().items()} self.assertItemsEqual(list(expected_dict.keys()), list(gotten_dict.keys())) for key, expected_values in expected_dict.items(): - self.assertAllClose(expected_values, gotten_dict[key]) - - def _testCase(self, predictions, labels, expected_result, weights=None): - """Performs a test given a certain scenario of labels, predictions, weights. - - Args: - predictions: The predictions tensor. Of type float32. - labels: The labels tensor. Of type bool. - expected_result: The expected result (dict) that maps to tensors. - weights: Optional weights tensor. - """ - with self.test_session() as sess: - predictions_tensor = constant_op.constant( - predictions, dtype=dtypes_lib.float32) - labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool) - weights_tensor = None - if weights: - weights_tensor = constant_op.constant(weights, dtype=dtypes_lib.float32) - gotten_result, update_op = ( - metric_ops.precision_recall_at_equal_thresholds( - labels=labels_tensor, - predictions=predictions_tensor, - weights=weights_tensor, - num_thresholds=3)) - - sess.run(variables.local_variables_initializer()) - sess.run(update_op) - - self._testResultsEqual(expected_result, gotten_result) + if eps is not None: + self.assertAllClose(expected_values, gotten_dict[key], atol=eps) + else: + self.assertAllClose(expected_values, gotten_dict[key]) def testVars(self): metric_ops.precision_recall_at_equal_thresholds( @@ -2414,6 +2391,78 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): for _ in range(3): self._testResultsEqual(initial_result, result) + def testLargeCase(self): + self.skipTest("Test consistently timing out") + shape = [32, 512, 256, 1] + predictions = random_ops.random_uniform( + shape, 0.0, 1.0, dtype=dtypes_lib.float32) + labels = math_ops.greater(random_ops.random_uniform(shape, 0.0, 1.0), 0.5) + + result, update_op = metric_ops.precision_recall_at_equal_thresholds( + labels=labels, 
predictions=predictions, num_thresholds=201) + # Run many updates, enough to cause highly inaccurate values if the + # code used float32 for accumulation. + num_updates = 71 + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + for _ in xrange(num_updates): + sess.run(update_op) + + prdata = sess.run(result) + + # Since we use random values, we won't know the tp/fp/tn/fn values, but + # tp and fp at threshold 0 should be the total number of positive and + # negative labels, hence their sum should be total number of pixels. + expected_value = 1.0 * np.product(shape) * num_updates + got_value = prdata.tp[0] + prdata.fp[0] + # They should be at least within 1. + self.assertNear(got_value, expected_value, 1.0) + + def _testCase(self, + predictions, + labels, + expected_result, + dtype=dtypes_lib.float32, + eps=None, + weights=None): + """Performs a test given a certain scenario of labels, predictions, weights. + + Args: + predictions: The predictions tensor. Of type dtype. + labels: The labels tensor. Of type bool. + expected_result: The expected result (dict) that maps to tensors. + dtype: Data type to use for predictions and weights tensor. Default + is float32. + eps: Epsilon value to use for testing output values. If unspecified, use + default from assertAllClose. + weights: Optional weights tensor. 
+ """ + with self.test_session() as sess: + predictions_tensor = constant_op.constant(predictions, dtype=dtype) + labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool) + weights_tensor = None + if weights: + weights_tensor = constant_op.constant(weights, dtype=dtype) + gotten_result, update_op = ( + metric_ops.precision_recall_at_equal_thresholds( + labels=labels_tensor, + predictions=predictions_tensor, + weights=weights_tensor, + num_thresholds=3)) + self.assertEqual(gotten_result.tp.dtype, dtype) + self.assertEqual(gotten_result.fp.dtype, dtype) + self.assertEqual(gotten_result.tn.dtype, dtype) + self.assertEqual(gotten_result.fn.dtype, dtype) + self.assertEqual(gotten_result.precision.dtype, dtype) + self.assertEqual(gotten_result.recall.dtype, dtype) + self.assertEqual(gotten_result.thresholds.dtype, dtype) + + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + + self._testResultsEqual(expected_result, gotten_result, eps=eps) + def testAllTruePositives(self): self._testCase( [[1]], [[True]], { @@ -2489,6 +2538,35 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): }, weights=[[0.0, 0.5, 2.0, 0.0, 0.5, 1.0]]) + def testFloat64(self): + self._testCase( + [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], + [[True, False, False, True, True, True]], { + 'tp': [4, 3, 0], + 'fp': [2, 0, 0], + 'tn': [0, 2, 2], + 'fn': [0, 1, 4], + 'precision': [2.0 / 3.0, 1.0, 0.0], + 'recall': [1.0, 0.75, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }, + dtype=dtypes_lib.float64) + + def testFloat16(self): + self._testCase( + [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], + [[True, False, False, True, True, True]], { + 'tp': [4, 3, 0], + 'fp': [2, 0, 0], + 'tn': [0, 2, 2], + 'fn': [0, 1, 4], + 'precision': [2.0 / 3.0, 1.0, 0.0], + 'recall': [1.0, 0.75, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }, + dtype=dtypes_lib.float16, + eps=1e-3) + class StreamingSpecificityAtSensitivityTest(test.TestCase): @@ -7101,6 +7179,14 @@ class CohenKappaTest(test.TestCase): with 
self.assertRaises(ValueError): metrics.cohen_kappa(labels, invalid_predictions, 3) + def testConditionalPackingOptimization(self): + placeholder = array_ops.placeholder(dtypes_lib.float32, [None]) + values, update_op = metric_ops.streaming_concat(placeholder) + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + for feed in range(10): + sess.run(update_op, feed_dict={placeholder: [feed]}) + print(sess.run(values)) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/mixed_precision/BUILD b/tensorflow/contrib/mixed_precision/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..3dfb95e0a006b13c23ea362bf622d80fd73703e6 --- /dev/null +++ b/tensorflow/contrib/mixed_precision/BUILD @@ -0,0 +1,32 @@ +# Mixed precision training optimizers + +package( + default_visibility = ["//tensorflow:internal"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "mixed_precision", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/mixed_precision/python:loss_scale_manager", + "//tensorflow/contrib/mixed_precision/python:loss_scale_optimizer", + ], +) diff --git a/tensorflow/contrib/mixed_precision/__init__.py b/tensorflow/contrib/mixed_precision/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..43e98cdda09222dc1334932265e516c6d460cdfc --- /dev/null +++ b/tensorflow/contrib/mixed_precision/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2018 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Library for mixed precision training.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.mixed_precision.python.loss_scale_manager import * +from tensorflow.contrib.mixed_precision.python.loss_scale_optimizer import * + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "LossScaleManager", + "FixedLossScaleManager", + "ExponentialUpdateLossScaleManager", + "LossScaleOptimizer", +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/mixed_precision/python/BUILD b/tensorflow/contrib/mixed_precision/python/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..1d769e16141e3eff664c449fc05b8441ee49d706 --- /dev/null +++ b/tensorflow/contrib/mixed_precision/python/BUILD @@ -0,0 +1,74 @@ +# Mixed precision training optimizers + +package( + default_visibility = ["//tensorflow:internal"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "loss_scale_manager", + srcs = ["loss_scale_manager.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:state_ops", 
"//tensorflow/python:variable_scope", + ], +) + +py_test( + name = "loss_scale_manager_test", + size = "small", + srcs = ["loss_scale_manager_test.py"], + deps = [ + ":loss_scale_manager", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_library( + name = "loss_scale_optimizer", + srcs = ["loss_scale_optimizer.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":loss_scale_manager", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + ], +) + +py_test( + name = "loss_scale_optimizer_test", + size = "small", + srcs = ["loss_scale_optimizer_test.py"], + deps = [ + ":loss_scale_optimizer", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_manager.py b/tensorflow/contrib/mixed_precision/python/loss_scale_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..be7377b1519f3bdab8755411af3de7aa0c2dc9eb --- /dev/null +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_manager.py @@ -0,0 +1,200 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""LossScaleManager classes for mixed precision training.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import six + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_control_flow_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope + + +@six.add_metaclass(abc.ABCMeta) +class LossScaleManager(object): + """Abstract loss scale manager class. + + Loss scale managers with a different strategy should subclass this class. + Loss scaling is a process that: + + 1) Applies a multiplier on the loss before computing gradients, and + 2) Applies the reciprocal of the multiplier on the gradients before they are + applied on variables. + + This class is used together with + @{tf.contrib.mixed_precision.LossScaleOptimizer} for mixed precision training + (float32 variables and float16 ops) on Nvidia GPUs in order to achieve the + same model quality as single precision training, with the benefits of + potential higher throughput. + + See @{tf.contrib.mixed_precision.LossScaleOptimizer} for more details. 
+ """ + + @abc.abstractmethod + def get_loss_scale(self): + """Returns the loss scale as a scalar `float32` tensor.""" + pass + + @abc.abstractmethod + def update_loss_scale(self, finite_grads): + """Updates loss scale based on if gradients are finite in current step. + + Args: + finite_grads: bool scalar tensor indicating if all gradients are + finite (i.e., not inf or nan). + + Returns: + An op, when executed updates the loss scale. If eager execution is + enabled, does not return anything. + """ + del finite_grads + return + + +class FixedLossScaleManager(LossScaleManager): + """Loss scale manager with a fixed loss scale. + + The loss scale is not updated for the lifetime of the class. + """ + + def __init__(self, loss_scale): + """Creates the fixed loss scale manager. + + Args: + loss_scale: A Python float. Its ideal value varies depending on models to + run. Choosing a too small loss_scale might affect model quality; a too + big loss_scale might cause inf or nan. There is no single right + loss_scale to apply. There is no harm choosing a relatively big number + as long as no nan or inf is encountered in training. + + Raises: + ValueError: If loss_scale is less than 1. + """ + if loss_scale < 1: + raise ValueError("loss scale must be at least 1.") + self._loss_scale = ops.convert_to_tensor(loss_scale, dtype=dtypes.float32) + + def get_loss_scale(self): + return self._loss_scale + + def update_loss_scale(self, finite_grads): + del finite_grads + return gen_control_flow_ops.no_op() + + +class ExponentialUpdateLossScaleManager(LossScaleManager): + """Loss scale manager uses an exponential update strategy. + + In general, the strategy increases loss scale by a greater-than-one factor + after encountering a consecutive series of steps with finite gradients; + Similarly, it decreases the loss scale by a factor when the accumulated number + of steps with non-finite (nan or inf) gradients are met. 
An update is not + applied if its result is less than 1 or overflows the float32 dynamic range. + + The number of finite and non-finite steps are cleared every time the loss + scale is changed. The condition to decrease the loss scale is looser than to + increase it since the former does not require the steps to be consecutive. + """ + + def __init__(self, + init_loss_scale, + incr_every_n_steps, + decr_every_n_nan_or_inf=2, + incr_ratio=2, + decr_ratio=0.8): + """Constructor of exponential-update loss scale manager. + + Args: + init_loss_scale: A Python float. The loss scale to use at the beginning. + incr_every_n_steps: Increases loss scale every n consecutive steps with + finite gradients. + decr_every_n_nan_or_inf: Decreases loss scale every n accumulated steps + with nan or inf gradients. + incr_ratio: The multiplier to use when increasing the loss scale. + decr_ratio: The less-than-one-multiplier to use when decreasing the loss + scale. + """ + self._incr_every_n_steps = incr_every_n_steps + self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf + self._incr_ratio = incr_ratio + self._decr_ratio = decr_ratio + self._loss_scale = variable_scope.variable( + name="loss_scale", + initial_value=ops.convert_to_tensor(init_loss_scale, dtypes.float32), + dtype=dtypes.float32, + trainable=False) + self._num_good_steps = variable_scope.variable( + name="good_steps", initial_value=0, dtype=dtypes.int32, trainable=False) + self._num_bad_steps = variable_scope.variable( + name="bad_steps", initial_value=0, dtype=dtypes.int32, trainable=False) + + def _reset_stats(self): + return control_flow_ops.group( + state_ops.assign(self._num_good_steps, 0), + state_ops.assign(self._num_bad_steps, 0)) + + def get_loss_scale(self): + """Returns the loss scale.""" + return self._loss_scale + + def update_loss_scale(self, finite_grads): + """Updates loss scale based on if gradients are finite in current step.""" + + def update_if_finite_grads(): + """Branch function when grads are all 
finite.""" + + def incr_loss_scale(): + new_loss_scale = control_flow_ops.cond( + gen_math_ops.is_finite(self._loss_scale * self._incr_ratio), + lambda: self._loss_scale * self._incr_ratio, + lambda: self._loss_scale) + update_op = state_ops.assign(self._loss_scale, new_loss_scale) + # When loss_scale is updated, both good and bad steps are reset. + return control_flow_ops.group(update_op, self._reset_stats()) + + return control_flow_ops.cond( + self._num_good_steps + 1 >= self._incr_every_n_steps, + incr_loss_scale, + lambda: state_ops.assign_add(self._num_good_steps, 1).op) + + def update_if_not_finite_grads(): + """Branch function when any grad is not finite.""" + + def decr_loss_scale(): + update_op = state_ops.assign( + self._loss_scale, + gen_math_ops.maximum(1., self._loss_scale * self._decr_ratio)) + # When loss_scale is updated, both good and bad steps are reset. + return control_flow_ops.group(update_op, self._reset_stats()) + + def just_update_steps(): + # When bad_steps is incremented, good_step is reset. + return control_flow_ops.group( + state_ops.assign_add(self._num_bad_steps, 1), + state_ops.assign(self._num_good_steps, 0)) + + return control_flow_ops.cond( + self._num_bad_steps + 1 >= self._decr_every_n_nan_or_inf, + decr_loss_scale, just_update_steps) + + return control_flow_ops.cond(finite_grads, update_if_finite_grads, + update_if_not_finite_grads) diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py new file mode 100644 index 0000000000000000000000000000000000000000..480f5f6eaf493c5c87c27cc9f8e510ea9c085a72 --- /dev/null +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py @@ -0,0 +1,182 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for LossScaleManager classes.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.mixed_precision.python import loss_scale_manager as lsm_lib +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def _GetExampleIter(inputs): + dataset = dataset_ops.Dataset.from_tensor_slices(inputs) + return dataset.make_one_shot_iterator() + + +class FixedLossScaleManagerTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def test_basic(self): + itr = _GetExampleIter([True] * 10 + [False] * 10) + + loss_scale = 1000 + lsm = lsm_lib.FixedLossScaleManager(loss_scale) + update_fn = lambda: lsm.update_loss_scale(itr.get_next()) + + self.evaluate(variables.global_variables_initializer()) + if not context.executing_eagerly(): + update_op = update_fn() + for _ in range(10): + if context.executing_eagerly(): + update_fn() + else: + self.evaluate(update_op) + self.assertEqual(loss_scale, self.evaluate(lsm.get_loss_scale())) + + +class ExponentialUpdateLossScaleManagerTest(test.TestCase): + + def _test_helper(self, + inputs, + expected_outputs, + init_loss_scale=1, + incr_every_n_step=2, + decr_every_n_nan_or_inf=2): + ratio = 2 + lsm = lsm_lib.ExponentialUpdateLossScaleManager(
+ init_loss_scale=init_loss_scale, + incr_every_n_steps=incr_every_n_step, + decr_every_n_nan_or_inf=decr_every_n_nan_or_inf, + incr_ratio=ratio, + decr_ratio=1. / ratio) + itr = _GetExampleIter(inputs) + update_fn = lambda: lsm.update_loss_scale(itr.get_next()) + + self.evaluate(variables.global_variables_initializer()) + actual_outputs = [] + + if not context.executing_eagerly(): + update_op = update_fn() + for _ in range(len(inputs)): + if context.executing_eagerly(): + update_fn() + else: + self.evaluate(update_op) + actual_outputs.append(self.evaluate(lsm.get_loss_scale())) + self.assertEqual(actual_outputs, expected_outputs) + + @test_util.run_in_graph_and_eager_modes() + def test_increase_every_n_steps(self): + inputs = [True] * 6 + expected_outputs = [1, 2, 2, 4, 4, 8] + self._test_helper(inputs, expected_outputs) + + @test_util.run_in_graph_and_eager_modes() + def test_keep_increasing_until_capped(self): + init_loss_scale = np.finfo(np.float32).max / 4 + 10 + max_float = np.finfo(np.float32).max + + inputs = [True] * 6 + # Output is capped the 2nd time it doubles. 
+ expected_outputs = [ + init_loss_scale, init_loss_scale * 2, init_loss_scale * 2, max_float, + max_float, max_float + ] + + self._test_helper(inputs, expected_outputs, init_loss_scale) + + @test_util.run_in_graph_and_eager_modes() + def test_decrease_every_n_steps(self): + inputs = [False] * 6 + init_loss_scale = 1024 + expected_outputs = [1024, 512, 512, 256, 256, 128] + + self._test_helper(inputs, expected_outputs, init_loss_scale) + + @test_util.run_in_graph_and_eager_modes() + def test_keep_decreasing_until_one(self): + inputs = [False] * 10 + init_loss_scale = 16 + expected_outputs = [16, 8, 8, 4, 4, 2, 2, 1, 1, 1] + + self._test_helper(inputs, expected_outputs, init_loss_scale) + + @test_util.run_in_graph_and_eager_modes() + def test_incr_bad_step_clear_good_step(self): + inputs = [True, True, True, False, True] + expected_outputs = [1, 2, 2, 2, 2] + self._test_helper(inputs, expected_outputs) + + @test_util.run_in_graph_and_eager_modes() + def test_incr_good_step_does_not_clear_bad_step(self): + inputs = [True, True, True, False, True, False] + expected_outputs = [1, 2, 2, 2, 2, 1] + self._test_helper(inputs, expected_outputs) + + @test_util.run_in_graph_and_eager_modes() + def test_trigger_loss_scale_update_each_step(self): + """Test when incr_every_n_step and decr_every_n_nan_or_inf is 1.""" + init_loss_scale = 1 + incr_every_n_step = 1 + decr_every_n_nan_or_inf = 1 + + inputs = [True] * 3 + [False, True, True] + expected_outputs = [2, 4, 8, 4, 8, 16] + + self._test_helper(inputs, expected_outputs, init_loss_scale, + incr_every_n_step, decr_every_n_nan_or_inf) + + @test_util.run_in_graph_and_eager_modes() + def test_alternating_good_and_bad_gradients_trigger_each_step(self): + init_loss_scale = 1 + incr_every_n_step = 1 + decr_every_n_nan_or_inf = 1 + + inputs = [True, False] * 4 + [True] + expected_outputs = [2, 1, 2, 1, 2, 1, 2, 1, 2] + self._test_helper(inputs, expected_outputs, init_loss_scale, + incr_every_n_step, decr_every_n_nan_or_inf) + + 
@test_util.run_in_graph_and_eager_modes() + def test_alternating_good_and_bad_gradients_trigger_incr_every_2steps(self): + init_loss_scale = 32 + incr_every_n_step = 2 + decr_every_n_nan_or_inf = 1 + + inputs = [True, False] * 3 + [True] + expected_outputs = [32, 16, 16, 8, 8, 4, 4] + self._test_helper(inputs, expected_outputs, init_loss_scale, + incr_every_n_step, decr_every_n_nan_or_inf) + + @test_util.run_in_graph_and_eager_modes() + def test_random_mix_good_and_bad_gradients(self): + init_loss_scale = 4 + inputs = [ + False, False, True, True, True, False, True, False, True, True, True, + False + ] + expected_outputs = [4, 2, 2, 4, 4, 4, 4, 2, 2, 4, 4, 4] + self._test_helper(inputs, expected_outputs, init_loss_scale) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ef34f7bf7bf3eba047b50ce8abf883b0ed741a63 --- /dev/null +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py @@ -0,0 +1,172 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Loss scaling optimizer.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import context +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_control_flow_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.training import optimizer + + +class LossScaleOptimizer(optimizer.Optimizer): + # TODO(jamesqin): move mixed precision training explanation to __init__ + # docstring. + """An optimizer that applies loss scaling in backprop. + + This class is useful for "mixed precision training" on GPUs (or other + potential accelerators), an approach to improve compute throughput without + compromising model quality. + + The canonical way to perform mixed precision training is the following: + * Model variables are kept in high precision (e.g. float32). + * Computations are done in lower precision (e.g. float16), which enjoys + performance speedup by virtue of hardware support. Variables are cast to + lower precision before they're used. + * Final gradients are cast back to high precision dtype, then used to update + variables. + + The side effect of performing computation in lower precision is that it comes + with a smaller numerical range. During backprop, small gradients might + underflow in the reduced numerical range, causing a model to converge at + a suboptimal level. + + To prevent underflow, this optimizer multiplies the loss by a factor before + backprop starts. Consequently, the gradients are linearly scaled up by the + same factor, thus not falling into the underflow zone. After that, to preserve + the correctness of backprop, the gradients are down-scaled by the same factor, + cast to the (higher) variable precision, then applied on the variables.
+ + See [Nvidia's manual on mixed precision training]( + https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html) + for more details. + + To use loss scale optimizer, one only needs to choose a loss scale strategy and + wrap a regular optimizer. See examples below. + + ``` + loss = loss_fn() + opt = tf.train.AdamOptimizer(learning_rate=...) + + # Choose a loss scale manager which decides how to pick the right loss scale + # throughout the training process. + loss_scale_manager = tf.contrib.mixed_precision.FixedLossScaleManager(5000) + + # Wraps the original optimizer in a LossScaleOptimizer. + loss_scale_optimizer = LossScaleOptimizer(opt, loss_scale_manager) + + # Call minimize() on the loss scale optimizer. + train_op = loss_scale_optimizer.minimize(loss) + ``` + + If gradient clipping is applied, one can call + `optimizer.compute_gradients()` and `optimizer.apply_gradients()` + separately. + + Notice the following way of using LossScaleOptimizer is not intended. Always + use `loss_scale_optimizer.compute_gradients()` to compute gradients instead of + `tf.gradients()` if doing mixed precision training. + + ``` + # The following is a wrong way to use LossScaleOptimizer along with + # tf.gradients(). + + # Always use loss_scale_optimizer.compute_gradients() to compute grads, or + # loss scale is not correctly applied. + grads = tf.gradients(loss, ...) + + # Do some custom grad clipping. + grads = clip_grads(grads, ...) + + loss_scale_optimizer.apply_gradients(grads_and_vars) + ``` + """ + + def __init__(self, opt, loss_scale_manager): + """Construct a loss scaling optimizer. + + Args: + opt: The actual optimizer that will be used to compute and apply the + gradients. Must be an implementation of the @{tf.train.Optimizer} + interface. + loss_scale_manager: A LossScaleManager object.
+ """ + self._opt = opt + self._loss_scale_manager = loss_scale_manager + + def compute_gradients(self, + loss, + var_list=None, + gate_gradients=optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + grad_loss=None): + """Compute gradients. See base class @{tf.train.Optimizer}.""" + loss_scale = self._loss_scale_manager.get_loss_scale() + if context.executing_eagerly(): + + def scaled_loss(): + loss_val = loss() + return loss_val * math_ops.cast(loss_scale, loss_val.dtype.base_dtype) + else: + if callable(loss): + loss_val = loss() + else: + loss_val = loss + scaled_loss = loss_val * math_ops.cast(loss_scale, + loss_val.dtype.base_dtype) + grads_and_vars = self._opt.compute_gradients( + scaled_loss, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + grad_loss=grad_loss) + return self._down_scale(grads_and_vars, loss_scale) + + def apply_gradients(self, grads_and_vars, global_step=None, name=None): + """Apply gradients. See base class @{tf.train.Optimizer}.""" + grads = [g for (g, _) in grads_and_vars] + + is_finite_grad = [] + for g in grads: + is_finite_grad.append(math_ops.reduce_all(gen_math_ops.is_finite(g))) + is_overall_finite = math_ops.reduce_all(is_finite_grad) + + # Only update gradients when all grads are finite. + def true_apply_gradients_fn(): + return self._opt.apply_gradients(grads_and_vars, global_step, name) + + update_vars = control_flow_ops.cond( + is_overall_finite, true_apply_gradients_fn, gen_control_flow_ops.no_op) + # Potentially adjust gradient scale in case of finite gradients. + return control_flow_ops.group( + update_vars, + self._loss_scale_manager.update_loss_scale(is_overall_finite)) + + def _down_scale(self, grads_vars, loss_scale): + # Down scale grads by the loss_scale. 
+ gv = [] + inv_loss_scale = gen_math_ops.reciprocal(loss_scale) + for g, v in grads_vars: + if g is not None: + gv.append((g * math_ops.cast(inv_loss_scale, g.dtype.base_dtype), v)) + else: + gv.append((g, v)) + return gv diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..dded61ccd58eb79b338d7264e8a057c9456c8695 --- /dev/null +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py @@ -0,0 +1,216 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for LossScaleOptimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.mixed_precision.python import loss_scale_manager as lsm_lib +from tensorflow.contrib.mixed_precision.python import loss_scale_optimizer as lso +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import gradient_descent as gd + + +class LossScaleOptimizerTest(test.TestCase): + + def _build_graph(self, lr, init_val, loss_scale_opt_fn=None): + x = variable_scope.get_variable( + "x", initializer=init_val, dtype=dtypes.float32) + c1 = constant_op.constant(1e4, dtype=dtypes.float16) + c2 = constant_op.constant(1e-4, dtype=dtypes.float16) + c3 = constant_op.constant(1e-4, dtype=dtypes.float16) + if context.executing_eagerly(): + loss = lambda: math_ops.cast(x, dtypes.float16) * c1 * c2 * c3 + else: + loss = math_ops.cast(x, dtypes.float16) * c1 * c2 * c3 + + opt = gd.GradientDescentOptimizer(lr) + if loss_scale_opt_fn: + opt = loss_scale_opt_fn(opt) + return x, loss, opt + + @test_util.run_in_graph_and_eager_modes() + def test_float16_underflow_without_loss_scale(self): + lr = 1 + init_val = 1. 
+ x, loss, opt = self._build_graph(lr, init_val) + + self.evaluate(variables.global_variables_initializer()) + self.evaluate(opt.minimize(loss, var_list=[x])) + + # Symbolic grad is c1 * c2 * c3 = 1e-4 and actual grad is 0, since in + # backprop, c2 * c3 underflows in fp16 range. So variable isn't updated. + expected_update = 0 + symbolic_update = 1e-4 * lr + self.assertAllClose( + init_val - expected_update, + self.evaluate(x), + rtol=0, + atol=min(symbolic_update, 1e-6)) + + @test_util.run_in_graph_and_eager_modes() + def test_float16_with_loss_scale(self): + lr = 1. + init_val = 1. + + def loss_scale_opt_fn(opt): + return lso.LossScaleOptimizer(opt, lsm_lib.FixedLossScaleManager(1e4)) + + x, loss, opt = self._build_graph(lr, init_val, loss_scale_opt_fn) + + self.evaluate(variables.global_variables_initializer()) + self.evaluate(opt.minimize(loss, var_list=[x])) + + # Symbolic grad is c1 * c2 * c3 = 1e-4 and actual grad is the same, due to + # up-scaled loss before backprop starts. + expected_update = 1.e-4 * lr + self.assertAllClose( + init_val - expected_update, + self.evaluate(x), + rtol=0, + atol=min(expected_update, 1e-6)) + + @test_util.run_in_graph_and_eager_modes() + def test_compute_gradients_with_loss_scale(self): + lr = 1 + init_val = 1. + + def loss_scale_opt_fn(opt): + return lso.LossScaleOptimizer(opt, lsm_lib.FixedLossScaleManager(1e4)) + + x, loss, opt = self._build_graph(lr, init_val, loss_scale_opt_fn) + grads_and_vars = opt.compute_gradients(loss, var_list=[x]) + + self.assertEqual(len(grads_and_vars), 1) + + self.evaluate(variables.global_variables_initializer()) + g_v = self.evaluate(grads_and_vars[0][0]) + self.assertAllClose(g_v, 1e-4) + self.assertIs(grads_and_vars[0][1], x) + # Gradients aren't applied. + self.assertAllClose(init_val, self.evaluate(x), rtol=0, atol=1e-6) + + @test_util.run_in_graph_and_eager_modes() + def test_compute_gradients_without_loss_scale(self): + lr = 1 + init_val = 1. 
+ x, loss, opt = self._build_graph(lr, init_val) + grads_and_vars = opt.compute_gradients(loss, var_list=[x]) + + self.assertEqual(len(grads_and_vars), 1) + self.evaluate(variables.global_variables_initializer()) + g_v = self.evaluate(grads_and_vars[0][0]) + self.assertAllClose(g_v, 0) + + @test_util.run_in_graph_and_eager_modes() + def test_apply_gradients(self): + + x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32) + dataset = dataset_ops.Dataset.from_tensor_slices([np.nan, np.inf, 0.1]) + itr = dataset.make_one_shot_iterator() + + lr = 1 + opt = gd.GradientDescentOptimizer(lr) + lsm = lsm_lib.FixedLossScaleManager(1.e4) + opt = lso.LossScaleOptimizer(opt, lsm) + train_fn = lambda: opt.apply_gradients([(itr.get_next(), x)]) + if not context.executing_eagerly(): + train_op = train_fn() + + expected_output = [1, 1, 1 - 0.1] + actual_output = [] + + self.evaluate(variables.global_variables_initializer()) + for _ in range(3): + # nan or inf is not applied. + if context.executing_eagerly(): + train_fn() + else: + self.evaluate(train_op) + actual_output.append(self.evaluate(x)) + self.assertAllClose(expected_output, actual_output) + + @test_util.run_in_graph_and_eager_modes() + def test_apply_gradients_loss_scale_is_updated(self): + + class SimpleLossScaleManager(lsm_lib.LossScaleManager): + """A simple loss scale manager for easier testing. + + It increments loss scale by 1 if grads are finite, and decreases loss + scale by 1 if otherwise. 
+ """ + + def __init__(self, loss_scale): + self._loss_scale = variable_scope.variable( + name="loss_scale", + initial_value=loss_scale, + dtype=dtypes.float32, + trainable=False) + + def get_loss_scale(self): + return self._loss_scale + + def update_loss_scale(self, if_finite_grads): + return control_flow_ops.cond( + if_finite_grads, lambda: state_ops.assign_add(self._loss_scale, 1), + lambda: state_ops.assign_sub(self._loss_scale, 1)) + + x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32) + dataset = dataset_ops.Dataset.from_tensor_slices([np.nan, np.inf, 0.1]) + itr = dataset.make_one_shot_iterator() + + lr = 1 + init_loss_scale = 8 + opt = gd.GradientDescentOptimizer(lr) + lsm = SimpleLossScaleManager(init_loss_scale) + opt = lso.LossScaleOptimizer(opt, lsm) + train_fn = lambda: opt.apply_gradients([(itr.get_next(), x)]) + if not context.executing_eagerly(): + train_op = train_fn() + + self.evaluate(variables.global_variables_initializer()) + + expected_loss_scale = [ + init_loss_scale - 1, init_loss_scale - 2, init_loss_scale - 2 + 1 + ] + expected_output = [1, 1, 1 - 0.1] + actual_output = [] + for i in range(3): + # nan or inf is not applied. + if context.executing_eagerly(): + train_fn() + else: + self.evaluate(train_op) + actual_output.append(self.evaluate(x)) + self.assertAllClose(expected_loss_scale[i], + self.evaluate(lsm._loss_scale)) + self.assertAllClose(expected_output, actual_output) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/mpi_collectives/kernels/ring.h b/tensorflow/contrib/mpi_collectives/kernels/ring.h index 1d56d588bc49eda542303ae6ebb19602352ae01d..c001615d3ffbdf04194cf8fd1fd242542bf8f89d 100644 --- a/tensorflow/contrib/mpi_collectives/kernels/ring.h +++ b/tensorflow/contrib/mpi_collectives/kernels/ring.h @@ -129,7 +129,7 @@ cudaStream_t CudaStreamForMPI(); * has the fully accumulated Segment 1; and so on. The scatter-reduce is * complete. 
* - * Next, the allgather distributes these fully accumululated chunks across all + * Next, the allgather distributes these fully accumulated chunks across all * nodes. Communication proceeds in the same ring, once again in N-1 steps. At * the ith step, node j will send chunk (j - i + 1) and receive chunk (j - i). * For example, at the first iteration, the following transfers will occur: diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc index 4d8d922cb42d2974dab32cf4562bee3993bef098..5144f7c38c8650ebfced1dfcc9378263ebaad8c0 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc @@ -171,8 +171,7 @@ class NcclManagerTest : public ::testing::Test { private: static Allocator* GpuAllocator(BaseGPUDevice* device) { - return device->GetStepAllocator(AllocatorAttributes(), - nullptr /* step_resource_manager */); + return device->GetAllocator(AllocatorAttributes()); } static se::DeviceMemory AsDeviceMemory(const Scalar* cuda_memory) { diff --git a/tensorflow/contrib/opt/python/training/adamax_test.py b/tensorflow/contrib/opt/python/training/adamax_test.py index bc92a7006f1a0a56adafc486a75afa94e965cb2c..915e6504e1e59ff247a2715820bc31a4d4cc1944 100644 --- a/tensorflow/contrib/opt/python/training/adamax_test.py +++ b/tensorflow/contrib/opt/python/training/adamax_test.py @@ -198,11 +198,11 @@ class AdaMaxOptimizerTest(test.TestCase): self.assertTrue(beta1_power is not None) self.assertIn(beta1_power, opt_variables) - with ops.Graph().as_default(): - # Shouldn't return non-slot variables from other graphs. - self.assertEqual(0, len(opt.variables())) - if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. 
+ self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) # Fetch params to validate initial values self.assertAllClose([1.0, 2.0], self.evaluate(var0)) @@ -224,8 +224,10 @@ class AdaMaxOptimizerTest(test.TestCase): var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) - self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0), + rtol=1e-2) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1), + rtol=1e-2) if use_resource: self.assertEqual("var0_%d/AdaMax:0" % (i,), opt.get_slot(var=var0, name="m").name) diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer.py b/tensorflow/contrib/opt/python/training/model_average_optimizer.py index a7c97a1da2baf29914337094c6153447c997af08..b6b10e500b6af80ab61cbf74077ea8e70800662f 100644 --- a/tensorflow/contrib/opt/python/training/model_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/model_average_optimizer.py @@ -62,7 +62,7 @@ class ModelAverageCustomGetter(object): """ def __init__(self, worker_device): - """Create a new `ElasticAverageCustomGetter`. + """Create a new `ModelAverageCustomGetter`. Args: worker_device: String. Name of the `worker` job. 
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py index 9e2858d00ff192e56680b288651975410c63c539..64b95786b5c7a71ee514201d8eb60c26975938b5 100644 --- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py +++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py @@ -31,19 +31,20 @@ from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.keras._impl.keras.engine import training -from tensorflow.python.keras._impl.keras.layers import core +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import core from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops +from tensorflow.python.ops import template from tensorflow.python.ops import variable_scope -from tensorflow.python.training import checkpointable -from tensorflow.python.training import checkpointable_utils from tensorflow.python.training import saver as core_saver from tensorflow.python.training import training_util +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import util as checkpointable_utils class NonLayerCheckpointable(checkpointable.Checkpointable): @@ -139,8 +140,9 @@ class CheckpointingTests(test.TestCase): self.evaluate(checkpointable_utils.gather_initializers( root_checkpointable)) self.evaluate(train_op) - named_variables, serialized_graph = ( - checkpointable_utils._serialize_object_graph(root_checkpointable)) + named_variables, serialized_graph, _ = ( + 
checkpointable_utils._serialize_object_graph( + root_checkpointable, saveables_cache=None)) expected_checkpoint_names = ( # Created in the root node, so no prefix. "optimizer_step", @@ -163,24 +165,29 @@ class CheckpointingTests(test.TestCase): suffix = "/.ATTRIBUTES/VARIABLE_VALUE" expected_checkpoint_names = [ name + suffix for name in expected_checkpoint_names] + # The Dense layers also save get_config() JSON + expected_checkpoint_names.extend( + ["model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON", + "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"]) + named_variables = {v.name: v for v in named_variables} six.assertCountEqual(self, expected_checkpoint_names, named_variables.keys()) # Check that we've mapped to the right variable objects (not exhaustive) self.assertEqual( - "global_step:0", - named_variables["optimizer_step" + suffix].name) + "global_step", + named_variables["optimizer_step" + suffix].full_name) self.assertEqual( - "my_model/dense_1/kernel:0", - named_variables["model/_second/kernel" + suffix].name) + "my_model/dense_1/kernel", + named_variables["model/_second/kernel" + suffix].full_name) self.assertEqual( - "my_model/dense/kernel:0", - named_variables["model/_named_dense/kernel" + suffix].name) + "my_model/dense/kernel", + named_variables["model/_named_dense/kernel" + suffix].full_name) self.assertEqual( - "beta1_power:0", - named_variables["optimizer/beta1_power" + suffix].name) + "beta1_power", + named_variables["optimizer/beta1_power" + suffix].full_name) self.assertEqual( - "beta2_power:0", - named_variables["optimizer/beta2_power" + suffix].name) + "beta2_power", + named_variables["optimizer/beta2_power" + suffix].full_name) # Spot check the generated protocol buffers. 
self.assertEqual("optimizer", serialized_graph.nodes[0].children[1].local_name) @@ -205,7 +212,7 @@ class CheckpointingTests(test.TestCase): self.assertEqual( "my_model/dense/kernel/Adam:0", optimizer.get_slot( - var=named_variables["model/_named_dense/kernel" + suffix], + var=model._named_dense.kernel, name="m").name) self.assertEqual( "model/_named_dense/kernel" + suffix, @@ -417,16 +424,6 @@ class CheckpointingTests(test.TestCase): self.evaluate(root.save_counter)) # pylint: enable=cell-var-from-loop - def _get_checkpoint_name(self, name): - root = checkpointable.Checkpointable() - checkpointable_utils.add_variable( - root, name=name, shape=[1, 2], dtype=dtypes.float64) - named_variables, _ = checkpointable_utils._serialize_object_graph(root) - checkpoint_name, = named_variables.keys() - with ops.name_scope("root/" + checkpoint_name): - pass # Make sure we can use this as an op name if we prefix it. - return checkpoint_name - def testAnonymousVarsInInit(self): class Model(training.Model): @@ -617,6 +614,49 @@ class CheckpointingTests(test.TestCase): self.assertAllEqual(3., self.evaluate(beta1_power)) +class TemplateTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def test_checkpointable_save_restore(self): + + def _templated(): + v = variable_scope.get_variable( + "v", shape=[1], initializer=init_ops.zeros_initializer(), + use_resource=True) + v2 = variable_scope.get_variable( + "v2", shape=[1], initializer=init_ops.zeros_initializer(), + use_resource=True) + return v, v + 1., v2 + + save_template = template.make_template("s1", _templated) + v1_save, _, v2_save = save_template() + optimizer = adam.AdamOptimizer(0.0) + save_root = checkpointable_utils.Checkpoint( + my_template=save_template, optimizer=optimizer) + optimizer.minimize(v1_save.read_value) + self.evaluate([v.initializer for v in optimizer.variables()]) + self.evaluate(v1_save.assign([12.])) + self.evaluate(v2_save.assign([14.])) + checkpoint_directory = self.get_temp_dir() + 
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = save_root.save(checkpoint_prefix) + + load_template = template.make_template("s2", _templated) + load_optimizer = adam.AdamOptimizer(0.0) + load_root = checkpointable_utils.Checkpoint( + my_template=load_template, optimizer=load_optimizer) + status = load_root.restore(save_path) + var, var_plus_one, var2 = load_template() + load_optimizer.minimize(var.read_value) + self.assertEqual(2, len(load_template._checkpoint_dependencies)) + self.assertEqual("v", load_template._checkpoint_dependencies[0].name) + self.assertEqual("v2", load_template._checkpoint_dependencies[1].name) + status.assert_consumed().run_restore_ops() + self.assertAllEqual([12.], self.evaluate(var)) + self.assertAllEqual([13.], self.evaluate(var_plus_one)) + self.assertAllEqual([14.], self.evaluate(var2)) + + class CheckpointCompatibilityTests(test.TestCase): def _initialized_model(self): @@ -682,12 +722,22 @@ class CheckpointCompatibilityTests(test.TestCase): with self.assertRaises(AssertionError): self._check_sentinels(root) object_saver = checkpointable_utils.CheckpointableSaver(root) + self._set_sentinels(root) status = object_saver.restore(save_path) - with self.assertRaises(AssertionError): - status.assert_consumed() + if context.executing_eagerly(): + self._check_sentinels(root) + if context.executing_eagerly(): + with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"): + status.assert_consumed() + else: + # When graph building, we haven't read any keys, so we don't know + # whether the restore will be complete. 
+ with self.assertRaisesRegexp(AssertionError, "not restored"): + status.assert_consumed() status.run_restore_ops() self._check_sentinels(root) self._set_sentinels(root) + status = object_saver.restore(save_path) status.initialize_or_restore() self._check_sentinels(root) diff --git a/tensorflow/contrib/optimizer_v2/momentum_test.py b/tensorflow/contrib/optimizer_v2/momentum_test.py index 26724f66c2a1db1d01577b31b739af18f51d3976..24cdab462665adc6297b0e0821455a545c3880af 100644 --- a/tensorflow/contrib/optimizer_v2/momentum_test.py +++ b/tensorflow/contrib/optimizer_v2/momentum_test.py @@ -134,7 +134,6 @@ class MomentumOptimizerTest(test.TestCase): with context.eager_mode(): self.doTestBasic(use_resource=True, use_callable_params=True) - @test_util.run_in_graph_and_eager_modes(reset_test=True) def testVariablesAcrossGraphs(self): optimizer = momentum_lib.MomentumOptimizer(0.01, 0.5) with ops.Graph().as_default(): @@ -142,10 +141,7 @@ class MomentumOptimizerTest(test.TestCase): [1.0, 2.0], dtype=dtypes.float32, name="var0") var1 = resource_variable_ops.ResourceVariable( [3.0, 4.0], dtype=dtypes.float32, name="var1") - if context.executing_eagerly(): - loss = lambda: math_ops.reduce_sum(var0 + var1) - else: - loss = math_ops.reduce_sum(var0 + var1) + loss = math_ops.reduce_sum(var0 + var1) optimizer.minimize(loss) optimizer_variables = optimizer.variables() self.assertStartsWith(optimizer_variables[0].name, "var0") @@ -157,10 +153,7 @@ class MomentumOptimizerTest(test.TestCase): [1.0, 2.0], dtype=dtypes.float32, name="var2") var3 = resource_variable_ops.ResourceVariable( [3.0, 4.0], dtype=dtypes.float32, name="var3") - if context.executing_eagerly(): - loss = lambda: math_ops.reduce_sum(var2 + var3) - else: - loss = math_ops.reduce_sum(var2 + var3) + loss = math_ops.reduce_sum(var2 + var3) optimizer.minimize(loss) optimizer_variables = optimizer.variables() self.assertStartsWith(optimizer_variables[0].name, "var2") diff --git 
a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index 46bfbb729fa9cdfc98f4228f516a7c5c42f23059..f537318b32986c941b6c41eb363929e906027dd7 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -33,10 +33,10 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import checkpointable from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import optimizer as optimizer_v1 from tensorflow.python.training import slot_creator +from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util import nest @@ -360,7 +360,16 @@ class _OptimizerV2State(object): """ slot_variable = self.get_slot(var=variable, name=slot_name) if (slot_variable is None and context.executing_eagerly() and - slot_variable_position.is_simple_variable()): + slot_variable_position.is_simple_variable() + # Defer slot variable creation if there is an active variable creator + # scope. Generally we'd like to eagerly create/restore slot variables + # when possible, but this may mean that scopes intended to catch + # `variable` also catch its eagerly created slot variable + # unintentionally (specifically make_template would add a dependency on + # a slot variable if not for this case). Deferring is mostly harmless + # (aside from double initialization), and makes variable creator scopes + # behave the same way they do when graph building. 
+ and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access initializer = checkpointable.CheckpointInitialValue( checkpoint_position=slot_variable_position) slot_variable = self.create_slot( diff --git a/tensorflow/contrib/periodic_resample/BUILD b/tensorflow/contrib/periodic_resample/BUILD index 6ca7fe8b6e59b0dc24be76262d4f54f387e53e48..976b312e8345a801ad07f622b6117b88af2cf603 100644 --- a/tensorflow/contrib/periodic_resample/BUILD +++ b/tensorflow/contrib/periodic_resample/BUILD @@ -6,12 +6,13 @@ exports_files(["LICENSE"]) load( "//tensorflow:tensorflow.bzl", - "py_test", + "tf_cc_test", "tf_gen_op_libs", "tf_custom_op_library", "tf_custom_op_py_library", "tf_gen_op_wrapper_py", ) +load("//tensorflow:tensorflow.bzl", "py_test") cc_library( name = "all_ops", @@ -84,6 +85,20 @@ py_test( ":init_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradient_checker", + ], +) + +tf_cc_test( + name = "periodic_resample_op_cc_test", + size = "small", + srcs = [ + "ops/array_ops_test.cc", + ], + deps = [ + ":all_ops", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc index e18923c8aae74c66ce78f98eb5e615e99463af74..514689cf4543cd08632bd0321a78fa933c456467 100644 --- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc +++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc @@ -22,4 +22,9 @@ namespace tensorflow { REGISTER_KERNEL_BUILDER(Name("PeriodicResample").Device(DEVICE_CPU), PeriodicResampleOp); + +REGISTER_KERNEL_BUILDER(Name("PeriodicResampleOpGrad") + .Device(DEVICE_CPU), + PeriodicResampleOpGrad); + } // namespace tensorflow diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h 
b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h index 3ab588c45881c8f93b4c1bcdf7ccde39086a1ed7..42fba81a5cb9490c093062048f269704a110756a 100644 --- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h +++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h @@ -25,92 +25,202 @@ #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/work_sharder.h" namespace { -template -IndexT compute_input_index( - IndexVecT* target_dimensions, const IndexT& output_index, - const IndexVecT& original_dimensions, const int& adjustable_dimension, - const std::vector& dimension_ceiling, - const std::vector& cumulative_dimensions, IndexT* result, - std::vector* output_indices, const int& rank) { - *result = 0; - output_indices->clear(); +// Computes input tensor index for given output index during forward +// propagation through periodic_resample operation. +class InputIndexer { + public: + InputIndexer(const std::vector& output_dimensions, + const tensorflow::TensorShape& input_shape, + int adjustable_dimension) + : output_dimensions_(output_dimensions), + adjustable_dimension_(adjustable_dimension), + rank_(input_shape.dims()), + linear_output_index_(0), + linear_input_index_(0), + adjustable_dimension_carriage_sum_(0) { + auto input_dimensions = TensorShapeToVector(input_shape); + // factors by which input_dimensions increases/decreases w.r.t. 
+ // output_dimensions + dimension_ceiling_ = + ComputeDimensionCeiling(output_dimensions, input_dimensions); + cumulative_dimensions_ = ComputeCumulativeDimensions(); + + output_indices_.resize(output_dimensions_.size()); + input_indices_.resize(output_dimensions_.size()); + + // Compute index_factors + index_factors_.resize(rank_); + tensorflow::int64 last_index_factor = 1; + for (auto r = rank_ - 1; r >= 0; --r) { + index_factors_[r] = last_index_factor; + last_index_factor *= input_dimensions[r]; + } + } + + tensorflow::int64 linear_input_index() const { return linear_input_index_; } + + void MoveToOutputIndex(tensorflow::int64 output_index); + void IncrementOutputIndex(); + + private: + void RecomputeInputAdjustableDimensionIndex() { + tensorflow::int64 index = adjustable_dimension_carriage_sum_; + index *= output_dimensions_[adjustable_dimension_]; + index += output_indices_[adjustable_dimension_]; + input_indices_[adjustable_dimension_] = index; + } + + std::vector TensorShapeToVector( + const tensorflow::TensorShape& tensor_shape); + + std::vector ComputeDimensionCeiling( + const std::vector& output_dimensions, + const std::vector& input_dimensions); + + std::vector ComputeCumulativeDimensions(); + + const std::vector output_dimensions_; + std::vector dimension_ceiling_; + std::vector index_factors_; + std::vector cumulative_dimensions_; + std::vector output_indices_; + std::vector input_indices_; + + const int adjustable_dimension_; + const int rank_; + tensorflow::int64 linear_output_index_; + tensorflow::int64 linear_input_index_; + tensorflow::int64 adjustable_dimension_carriage_sum_; +}; + +void InputIndexer::MoveToOutputIndex(tensorflow::int64 output_index) { + linear_output_index_ = output_index; + linear_input_index_ = 0; // un-rasterize the output index auto last_reduced_i = output_index; - for (auto r = rank - 1; r >= 0; --r) { - (*output_indices)[r] = last_reduced_i % (*target_dimensions)[r]; + for (auto r = rank_ - 1; r >= 0; --r) { + 
output_indices_[r] = last_reduced_i % output_dimensions_[r]; last_reduced_i = - (last_reduced_i - (*output_indices)[r]) / (*target_dimensions)[r]; + (last_reduced_i - output_indices_[r]) / output_dimensions_[r]; } + tensorflow::int64 carriage_sum = 0; + for (int qi = 0; qi < rank_; ++qi) { + if (qi == adjustable_dimension_) continue; + carriage_sum += cumulative_dimensions_[qi] * + (output_indices_[qi] % dimension_ceiling_[qi]); + } + adjustable_dimension_carriage_sum_ = carriage_sum; + // rasterize the input index - IndexT last_index_factor = 1; - for (auto r = rank - 1; r >= 0; --r) { - IndexT index = 0; - if (r != adjustable_dimension) - index = (*output_indices)[r] / dimension_ceiling[r]; - else { - for (int qi = 0; qi < rank; ++qi) { - if (qi == adjustable_dimension) continue; - index += cumulative_dimensions[qi] * - ((*output_indices)[qi] % dimension_ceiling[qi]); - } - index *= (*target_dimensions)[adjustable_dimension]; - index += (*output_indices)[r]; + for (auto r = rank_ - 1; r >= 0; --r) { + if (r != adjustable_dimension_) { + input_indices_[r] = output_indices_[r] / dimension_ceiling_[r]; + } else { + RecomputeInputAdjustableDimensionIndex(); } - *result += last_index_factor * index; - last_index_factor *= original_dimensions[r]; } + for (auto r = rank_ - 1; r >= 0; --r) { + linear_input_index_ += index_factors_[r] * input_indices_[r]; + } +} + +void InputIndexer::IncrementOutputIndex() { + linear_output_index_++; + for (auto r = rank_ - 1; r >= 0; --r) { + auto old_carriage_sum_increment = + cumulative_dimensions_[r] * + (output_indices_[r] % dimension_ceiling_[r]); + output_indices_[r] = (output_indices_[r] + 1) % output_dimensions_[r]; + if (r != adjustable_dimension_) { + auto new_input_index = output_indices_[r] / dimension_ceiling_[r]; + linear_input_index_ += + (new_input_index - input_indices_[r]) * index_factors_[r]; + + input_indices_[r] = new_input_index; + + auto new_carriage_sum_increment = + cumulative_dimensions_[r] * + 
(output_indices_[r] % dimension_ceiling_[r]); - return *result; + adjustable_dimension_carriage_sum_ = adjustable_dimension_carriage_sum_ - + old_carriage_sum_increment + + new_carriage_sum_increment; + } + + if (output_indices_[r] != 0) { + // No more carries to higher indices. + break; + } + } + auto old_adjustable_dimension_input_index = + input_indices_[adjustable_dimension_]; + RecomputeInputAdjustableDimensionIndex(); + linear_input_index_ += (input_indices_[adjustable_dimension_] - + old_adjustable_dimension_input_index) * + index_factors_[adjustable_dimension_]; } -template // both types are needed here b/c IndexVecT and - // InputDataT are not related - void - fill_periodic_tensor( - tensorflow::OpKernelContext* context, - const IndexVecT& desired_shape, - const tensorflow::Tensor& input_tensor) { - // input is a strided array (last index is fastest, C-ordered) - auto input = input_tensor.flat(); - const int rank = input_tensor.dims(); - // original and target dimensions - std::vector original_dimensions(rank), - target_dimensions(rank); - tensorflow::int64 total_size(input_tensor.NumElements()), new_sliced_size(1); - // factors by which original_dimensions increases/decreases w.r.t. 
- // target_dimensions - std::vector dimension_ceiling(rank), - cumulative_dimensions(rank); - // index of adjustable dimension - int adjustable_dimension; - tensorflow::TensorShape output_shape; +std::vector InputIndexer::TensorShapeToVector( + const tensorflow::TensorShape& tensor_shape) { + std::vector result(tensor_shape.dims()); + int count = 0; + for (const auto dim_info : tensor_shape) { + result[count] = dim_info.size; + ++count; + } + return result; +} - // requires that the rank of the input tensor and length of the desired shape - // are equal - OP_REQUIRES(context, rank == desired_shape.size(), - tensorflow::errors::InvalidArgument( - "periodic_resample expects the rank of the input tensor, ", - rank, ", to be the same as the length of the desired shape, ", - desired_shape.size(), ".")); +std::vector InputIndexer::ComputeDimensionCeiling( + const std::vector& output_dimensions, + const std::vector& input_dimensions) { + std::vector dimension_ceiling(input_dimensions.size()); + for (size_t i = 0; i < input_dimensions.size(); ++i) { + dimension_ceiling[i] = (output_dimensions[i] + input_dimensions[i] - 1) / + input_dimensions[i]; + } + return dimension_ceiling; +} - bool found = false; - const auto& input_tensor_shape = input_tensor.shape(); +std::vector InputIndexer::ComputeCumulativeDimensions() { + std::vector cumulative_dimensions(rank_); + int count = 0; + for (int i = 0; i < rank_; ++i) { + if (count == 0) { + cumulative_dimensions[count] = 1; + } else { + cumulative_dimensions[count] = + cumulative_dimensions[count - 1] * dimension_ceiling_[count - 1]; + } + ++count; + } + return cumulative_dimensions; +} +template +void process_desired_shape(tensorflow::OpKernelContext* context, + const tensorflow::TensorShape& input_tensor_shape, + const IndexVecT& desired_shape, + int* adjustable_dimension, + std::vector* target_dimensions, + tensorflow::int64* output_size) { + tensorflow::int64 new_sliced_size = 1; + bool found = false; + const int rank = 
input_tensor_shape.dims(); for (int i = 0; i < rank; ++i) { - // if (desired_shape(i) < 1) { if (desired_shape[i] < 1) { // only one index can be adjustable OP_REQUIRES(context, !found, tensorflow::errors::InvalidArgument( "periodic_resample expects only " "one index to be marked as adjustable.")); - adjustable_dimension = i; + *adjustable_dimension = i; found = true; } else { OP_REQUIRES( @@ -122,9 +232,8 @@ template +void +do_periodic_resample_op(tensorflow::OpKernelContext* context, + const tensorflow::TensorShape& original_shape, + const tensorflow::PartialTensorShape& desired_shape, + const tensorflow::Tensor& source_tensor) { + const int rank = source_tensor.dims(); + + // requires that the rank of the input tensor and length of the desired shape + // are equal + OP_REQUIRES(context, rank == desired_shape.dims(), + tensorflow::errors::InvalidArgument( + "periodic_resample expects the rank of the input tensor, ", + rank, ", to be the same as the length of the desired shape, ", + desired_shape.dims(), ".")); + + std::vector target_dimensions(rank); + tensorflow::int64 new_size = 0; + // index of adjustable dimension + int adjustable_dimension = 0; + process_desired_shape(context, original_shape, desired_shape.dim_sizes(), + &adjustable_dimension, &target_dimensions, &new_size); // ensure that the new dimension is greater than zero OP_REQUIRES(context, target_dimensions[adjustable_dimension] > 0, @@ -160,11 +293,14 @@ template allocate_output(0, output_shape, &output_tensor)); auto output = output_tensor->flat(); - // memory is allocated for these variables outside the inner loop for - // efficiency (although, I could create a separate class scope for - // this purpose instead) - tensorflow::int64 result = 0; - std::vector output_indices(target_dimensions.size()); + // input is a strided array (last index is fastest, C-ordered) + auto input = source_tensor.flat(); // Fill output tensor with periodically resampled input tensor values - for (tensorflow::int64 
output_index = 0; output_index < new_size; - ++output_index) { - output(output_index) = input(compute_input_index( - &target_dimensions, output_index, original_dimensions, - adjustable_dimension, dimension_ceiling, cumulative_dimensions, &result, - &output_indices, rank)); - } + InputIndexer input_indexer(target_dimensions, original_shape, + adjustable_dimension); + + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + auto fill_output_tensor = [&input_indexer, &output, &input]( + tensorflow::int64 start, tensorflow::int64 limit) { + InputIndexer local_indexer(input_indexer); + local_indexer.MoveToOutputIndex(start); + for (tensorflow::int64 output_index = start; output_index < limit; + ++output_index) { + if (mode == Mode::kForward) { + output(output_index) = input(local_indexer.linear_input_index()); + } else { + output(local_indexer.linear_input_index()) = input(output_index); + } + local_indexer.IncrementOutputIndex(); + } + }; + ::tensorflow::Shard(worker_threads.num_threads, worker_threads.workers, + new_size, costPerFillIndex, fill_output_tensor); } +#define DATA_TYPE_SWITCH(data_type, context, CASE) \ + switch (data_type) { \ + CASE(float) \ + CASE(double) \ + CASE(tensorflow::int32) \ + CASE(tensorflow::int64) \ + default: \ + context->CtxFailure(__FILE__, __LINE__, \ + tensorflow::errors::InvalidArgument( \ + "Unsuppored tensor elements type")); \ + break; \ + } + void create_output_tensor( tensorflow::OpKernelContext* context, const tensorflow::Tensor& input_tensor, const tensorflow::DataType& input_tensor_type, - const tensorflow::PartialTensorShape& desired_shape_tensor) { - auto desired_shape = desired_shape_tensor.dim_sizes(); - - // obligatory type switch - switch (input_tensor_type) { - case tensorflow::DataTypeToEnum::value: - fill_periodic_tensor(context, desired_shape, input_tensor); + const tensorflow::PartialTensorShape& desired_shape) { +#define CASE(type) \ + case tensorflow::DataTypeToEnum::value: \ + 
do_periodic_resample_op( \ + context, input_tensor.shape(), desired_shape, input_tensor); \ break; - case tensorflow::DataTypeToEnum::value: - fill_periodic_tensor(context, desired_shape, input_tensor); - break; - case tensorflow::DataTypeToEnum::value: - fill_periodic_tensor(context, desired_shape, - input_tensor); - break; - case tensorflow::DataTypeToEnum::value: - fill_periodic_tensor(context, desired_shape, - input_tensor); + + DATA_TYPE_SWITCH(input_tensor_type, context, CASE); +#undef CASE +} + +void create_grad_tensor(tensorflow::OpKernelContext* context, + const tensorflow::Tensor& grad_tensor, + const tensorflow::DataType& grad_tensor_type, + const tensorflow::TensorShape& original_shape, + const tensorflow::PartialTensorShape& desired_shape) { +#define CASE(type) \ + case tensorflow::DataTypeToEnum::value: \ + do_periodic_resample_op( \ + context, original_shape, desired_shape, grad_tensor); \ break; - default:; - } + + DATA_TYPE_SWITCH(grad_tensor_type, context, CASE); +#undef CASE } } // namespace @@ -238,4 +400,25 @@ class PeriodicResampleOp : public tensorflow::OpKernel { tensorflow::PartialTensorShape desired_shape; }; +class PeriodicResampleOpGrad : public tensorflow::OpKernel { + public: + explicit PeriodicResampleOpGrad(tensorflow::OpKernelConstruction* context) + : tensorflow::OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("original_shape", &original_shape)); + OP_REQUIRES_OK(context, context->GetAttr("desired_shape", &desired_shape)); + } + + void Compute(tensorflow::OpKernelContext* context) override { + const tensorflow::Tensor& grad_tensor = context->input(0); + const tensorflow::DataType grad_tensor_type = context->input_dtype(0); + create_grad_tensor(context, grad_tensor, grad_tensor_type, original_shape, + desired_shape); + } + + private: + tensorflow::TensorShape original_shape; + tensorflow::PartialTensorShape desired_shape; +}; + #endif // TENSORFLOW_KERNELS_PERIODICRESAMPLE_OP_H_ diff --git 
a/tensorflow/contrib/periodic_resample/ops/array_ops.cc b/tensorflow/contrib/periodic_resample/ops/array_ops.cc index 82bd79695646e3673c2c78ad99dd2bd200fc2fbf..fd38cd09b4d0939d7955f7839763a8e955b71fa5 100644 --- a/tensorflow/contrib/periodic_resample/ops/array_ops.cc +++ b/tensorflow/contrib/periodic_resample/ops/array_ops.cc @@ -26,7 +26,42 @@ REGISTER_OP("PeriodicResample") .Input("values: T") .Attr("shape: shape") .Output("output: T") - .SetShapeFn(shape_inference::ExplicitShape) + .SetShapeFn([](shape_inference::InferenceContext* c) { + tensorflow::PartialTensorShape desired_shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape)); + shape_inference::ShapeHandle input_tensor_shape = c->input(0); + shape_inference::DimensionHandle num_input_elements = + c->NumElements(input_tensor_shape); + shape_inference::ShapeHandle result_shape_handle; + if (!shape_inference::InferenceContext::ValueKnown(num_input_elements)) { + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + desired_shape, &result_shape_handle)); + } else { + const int rank = c->Rank(input_tensor_shape); + std::vector target_dimensions(rank); + tensorflow::int64 new_sliced_size = 1; + int adjustable_dimension = 0; + for (int i = 0; i < rank; ++i) { + if (desired_shape.dim_size(i) < 1) { + adjustable_dimension = i; + } else { + target_dimensions[i] = desired_shape.dim_size(i); + new_sliced_size *= target_dimensions[i]; + } + } + target_dimensions[adjustable_dimension] = + shape_inference::InferenceContext::Value( + num_input_elements) / new_sliced_size; + tensorflow::TensorShape result_shape; + for (int i = 0; i < rank; ++i) { + result_shape.AddDim(target_dimensions[i]); + } + TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape( + result_shape, &result_shape_handle)); + } + c->set_output(0, result_shape_handle); + return Status::OK(); + }) .Doc(R"doc( Periodically resample elements of a tensor to conform to `shape`. 
@@ -101,4 +136,20 @@ output: Periodically resampled tensor that has dimensions specified as in )doc"); + +REGISTER_OP("PeriodicResampleOpGrad") + .Attr("T: numbertype") + .Input("grad: T") + .Attr("original_shape: shape") + .Attr("desired_shape: shape") + .Output("grad_values: T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + tensorflow::TensorShape original_shape; + TF_RETURN_IF_ERROR(c->GetAttr("original_shape", &original_shape)); + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(original_shape, &s)); + c->set_output(0, s); + return Status::OK(); +}); + } // namespace tensorflow diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc b/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..55edf76fcd3eed461e1465b569e1c2e9e2facbc0 --- /dev/null +++ b/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +TEST(ArrayOpsTest, PeriodicResample_ShapeFn) { + ShapeInferenceTestOp op("PeriodicResample"); + // Case 1: output shape can be fully inferred. + PartialTensorShape shape({4, 4, -1}); + TensorShapeProto shape_proto; + shape.AsProto(&shape_proto); + + TF_ASSERT_OK(NodeDefBuilder("test", "PeriodicResample") + .Input({"values", 0, DT_INT32}) + .Attr("shape", shape_proto) + .Finalize(&op.node_def)); + INFER_OK(op, "[2,2,4]", "[4,4,1]"); + // Case 2: output shape can not be inferred - report desired shape. + INFER_OK(op, "[2,2,?]", "[4,4,?]"); +} + +} // end namespace tensorflow diff --git a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py index a25de55e18b223db2b724aafb54b18d8f48a5baa..31a6fe1d94b8a972087e00cf7c676105b0f1129b 100644 --- a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py +++ b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py @@ -21,8 +21,11 @@ from __future__ import print_function import numpy from tensorflow.contrib.periodic_resample import periodic_resample +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -93,7 +96,6 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase): def
testPeriodicResampleErrors(self): input_tensor = numpy.zeros(shape=[1, 2, 2, 4]) with self.test_session(): - variables.global_variables_initializer().run() with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, 'Dimension 3 input tensor has size 4, desired shape has size 1'): @@ -103,6 +105,29 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase): '4, to be the same as the length of the desired shape, 3'): periodic_resample(input_tensor, [None, 4, 4]).eval() + def testPeriodicResampleGradient(self): + desired_shape = numpy.array([4, 4, None]) + result_shape = (4, 4, 1) + input_shape = (2, 2, 4) + with self.test_session() as sess: + x = array_ops.placeholder(dtypes.float32, shape=input_shape) + output = periodic_resample(x, desired_shape) + error = gradient_checker.compute_gradient_error( + x, input_shape, output, result_shape) + self.assertLess(error, 1e-4) + + def testPeriodicResampleShapeInference(self): + with self.test_session() as sess: + # Case 1: output shape can be fully inferred. + x = array_ops.placeholder(dtypes.float32, shape=(2, 2, 4)) + output = periodic_resample(x, [4, 4, None]) + self.assertEqual(output.shape, [4, 4, 1]) + # Case 2: output shape can not be inferred - report desired shape.
+ x = array_ops.placeholder(dtypes.float32, shape=(2, 2, None)) + output = periodic_resample(x, [4, 4, None]) + self.assertTrue(output.shape.is_compatible_with([4, 4, None])) + self.assertEqual(output.shape[2].value, None) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py index 348623d8f8d0c2ed60f559eca281343722038100..470e300ccbe7108fd49718341f4a522683366fe3 100644 --- a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py +++ b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py @@ -21,11 +21,17 @@ from __future__ import print_function # pylint: disable=unused-import from tensorflow.contrib.periodic_resample.python.ops import gen_periodic_resample_op -from tensorflow.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample +from tensorflow.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample, periodic_resample_op_grad from tensorflow.contrib.util import loader +from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader # pylint: enable=unused-import _periodic_resample_op = loader.load_op_library( resource_loader.get_path_to_datafile('_periodic_resample_op.so')) + +@ops.RegisterGradient("PeriodicResample") +def _periodic_resample_grad_cc(op, grad): + return periodic_resample_op_grad( + grad, op.inputs[0].shape, op.get_attr('shape')) diff --git a/tensorflow/contrib/predictor/contrib_estimator_predictor.py b/tensorflow/contrib/predictor/contrib_estimator_predictor.py index b7a98c68e2343e9c8bb4b41556dc96bfe4ef444c..af3b2ad1b531b835f484a155efcc57bbe634f2df 100644 --- a/tensorflow/contrib/predictor/contrib_estimator_predictor.py +++ b/tensorflow/contrib/predictor/contrib_estimator_predictor.py @@ -34,7 +34,8 @@ class ContribEstimatorPredictor(predictor.Predictor): 
prediction_input_fn, input_alternative_key=None, output_alternative_key=None, - graph=None): + graph=None, + config=None): """Initialize a `ContribEstimatorPredictor`. Args: @@ -48,6 +49,7 @@ class ContribEstimatorPredictor(predictor.Predictor): multi-headed models. graph: Optional. The Tensorflow `graph` in which prediction should be done. + config: `ConfigProto` proto used to configure the session. """ self._graph = graph or ops.Graph() with self._graph.as_default(): @@ -58,6 +60,7 @@ class ContribEstimatorPredictor(predictor.Predictor): checkpoint_path = saver.latest_checkpoint(estimator.model_dir) self._session = monitored_session.MonitoredSession( session_creator=monitored_session.ChiefSessionCreator( + config=config, checkpoint_filename_with_path=checkpoint_path)) input_alternative_key = ( diff --git a/tensorflow/contrib/predictor/core_estimator_predictor.py b/tensorflow/contrib/predictor/core_estimator_predictor.py index d78d94c2699b14c80e7decee2181d190a6d91f99..a725072e72df2db64cde5ea31ab16e7c2dc5d2ce 100644 --- a/tensorflow/contrib/predictor/core_estimator_predictor.py +++ b/tensorflow/contrib/predictor/core_estimator_predictor.py @@ -51,7 +51,8 @@ class CoreEstimatorPredictor(predictor.Predictor): estimator, serving_input_receiver_fn, output_key=None, - graph=None): + graph=None, + config=None): """Initialize a `CoreEstimatorPredictor`. Args: @@ -62,6 +63,7 @@ class CoreEstimatorPredictor(predictor.Predictor): `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. graph: Optional. The Tensorflow `graph` in which prediction should be done. + config: `ConfigProto` proto used to configure the session. 
""" self._graph = graph or ops.Graph() with self._graph.as_default(): @@ -71,6 +73,7 @@ class CoreEstimatorPredictor(predictor.Predictor): checkpoint_dir = estimator.model_dir self._session = monitored_session.MonitoredSession( session_creator=monitored_session.ChiefSessionCreator( + config=config, checkpoint_dir=checkpoint_dir)) feed_tensor_info = signature_def.inputs diff --git a/tensorflow/contrib/predictor/predictor_factories.py b/tensorflow/contrib/predictor/predictor_factories.py index 6e77e934fe19851eea9ed0b74eb7aecc76f6237a..f275bc15adfa0a51a48964dff8edddbd45500e45 100644 --- a/tensorflow/contrib/predictor/predictor_factories.py +++ b/tensorflow/contrib/predictor/predictor_factories.py @@ -30,7 +30,8 @@ def from_contrib_estimator(estimator, prediction_input_fn, input_alternative_key=None, output_alternative_key=None, - graph=None): + graph=None, + config=None): """Constructs a `Predictor` from a `tf.contrib.learn.Estimator`. Args: @@ -44,6 +45,7 @@ def from_contrib_estimator(estimator, multi-headed models. graph: Optional. The Tensorflow `graph` in which prediction should be done. + config: `ConfigProto` proto used to configure the session. Returns: An initialized `Predictor`. @@ -62,13 +64,15 @@ def from_contrib_estimator(estimator, prediction_input_fn, input_alternative_key=input_alternative_key, output_alternative_key=output_alternative_key, - graph=graph) + graph=graph, + config=config) def from_estimator(estimator, serving_input_receiver_fn, output_key=None, - graph=None): + graph=None, + config=None): """Constructs a `Predictor` from a `tf.python.estimator.Estimator`. Args: @@ -79,6 +83,7 @@ def from_estimator(estimator, `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. graph: Optional. The Tensorflow `graph` in which prediction should be done. + config: `ConfigProto` proto used to configure the session. Returns: An initialized `Predictor`. @@ -93,14 +98,19 @@ def from_estimator(estimator, 'tf.contrib.learn.Estimator. 
You likely want to call ' 'from_contrib_estimator.') return core_estimator_predictor.CoreEstimatorPredictor( - estimator, serving_input_receiver_fn, output_key=output_key, graph=graph) + estimator, + serving_input_receiver_fn, + output_key=output_key, + graph=graph, + config=config) def from_saved_model(export_dir, signature_def_key=None, signature_def=None, tags=None, - graph=None): + graph=None, + config=None): """Constructs a `Predictor` from a `SavedModel` on disk. Args: @@ -115,6 +125,7 @@ def from_saved_model(export_dir, `SignatureDef`. Defaults to `DEFAULT_TAGS`. graph: Optional. The Tensorflow `graph` in which prediction should be done. + config: `ConfigProto` proto used to configure the session. Returns: An initialized `Predictor`. @@ -128,4 +139,5 @@ def from_saved_model(export_dir, signature_def_key=signature_def_key, signature_def=signature_def, tags=tags, - graph=graph) + graph=graph, + config=config) diff --git a/tensorflow/contrib/predictor/predictor_factories_test.py b/tensorflow/contrib/predictor/predictor_factories_test.py index 578d9424b25dd38f1d77a267d1fdf1ff9ff2da88..a2ef1dc3af0986afacf646f0dc04b7ef857a7f93 100644 --- a/tensorflow/contrib/predictor/predictor_factories_test.py +++ b/tensorflow/contrib/predictor/predictor_factories_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib.predictor import predictor_factories from tensorflow.contrib.predictor import testing_common +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.platform import test MODEL_DIR_NAME = 'contrib/predictor/test_export_dir' @@ -41,6 +42,11 @@ class PredictorFactoriesTest(test.TestCase): """Test loading from_saved_model with tags.""" predictor_factories.from_saved_model(self._export_dir, tags='serve') + def testFromSavedModelWithSessionConfig(self): + """Test loading from_saved_model with session config.""" + predictor_factories.from_saved_model( + self._export_dir, config=config_pb2.ConfigProto()) + def 
testFromSavedModelWithBadTags(self): """Test that loading fails for bad tags.""" bad_tags_regex = ('.*? could not be found in SavedModel') @@ -53,6 +59,13 @@ class PredictorFactoriesTest(test.TestCase): predictor_factories.from_contrib_estimator( estimator, input_fn, output_alternative_key='sum') + def testFromContribEstimatorWithSessionConfig(self): + estimator = testing_common.get_arithmetic_estimator(core=False) + input_fn = testing_common.get_arithmetic_input_fn(core=False) + predictor_factories.from_contrib_estimator( + estimator, input_fn, output_alternative_key='sum', + config=config_pb2.ConfigProto()) + def testFromContribEstimatorWithCoreEstimatorRaises(self): estimator = testing_common.get_arithmetic_estimator(core=True) input_fn = testing_common.get_arithmetic_input_fn(core=True) @@ -64,6 +77,12 @@ class PredictorFactoriesTest(test.TestCase): input_fn = testing_common.get_arithmetic_input_fn(core=True) predictor_factories.from_estimator(estimator, input_fn) + def testFromCoreEstimatorWithSessionConfig(self): + estimator = testing_common.get_arithmetic_estimator(core=True) + input_fn = testing_common.get_arithmetic_input_fn(core=True) + predictor_factories.from_estimator( + estimator, input_fn, config=config_pb2.ConfigProto()) + def testFromCoreEstimatorWithContribEstimatorRaises(self): estimator = testing_common.get_arithmetic_estimator(core=False) input_fn = testing_common.get_arithmetic_input_fn(core=False) diff --git a/tensorflow/contrib/predictor/saved_model_predictor.py b/tensorflow/contrib/predictor/saved_model_predictor.py index 0dbca0f8136e4e618234101ee41c80bc085511c0..95da6d04edc5214d1b5c1851c4ab05c6d7080b9b 100644 --- a/tensorflow/contrib/predictor/saved_model_predictor.py +++ b/tensorflow/contrib/predictor/saved_model_predictor.py @@ -121,7 +121,8 @@ class SavedModelPredictor(predictor.Predictor): input_names=None, output_names=None, tags=None, - graph=None): + graph=None, + config=None): """Initialize a `CoreEstimatorPredictor`. 
Args: @@ -142,6 +143,7 @@ class SavedModelPredictor(predictor.Predictor): the correct `SignatureDef`. Defaults to `DEFAULT_TAGS`. graph: Optional. The Tensorflow `graph` in which prediction should be done. + config: `ConfigProto` proto used to configure the session. Raises: ValueError: If more than one of signature_def_key OR signature_def OR (input_names AND output_names) is specified. @@ -152,7 +154,7 @@ class SavedModelPredictor(predictor.Predictor): self._graph = graph or ops.Graph() with self._graph.as_default(): - self._session = session.Session() + self._session = session.Session(config=config) loader.load(self._session, tags.split(','), export_dir) if input_names is None: diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index b9918fdee1ece2bae1ab1459985066a35b6431be..23363617eddd2078db9052a64d70d5f8c234805d 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -155,8 +155,10 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:partitioned_variables", "//tensorflow/python:platform_test", "//tensorflow/python:session", + "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], ) diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md index c83623ec947c1550991352a9dd9a5c6ee9282290..27a933c0f945e53a1838aefd30aed82fadbbc146 100644 --- a/tensorflow/contrib/quantize/README.md +++ b/tensorflow/contrib/quantize/README.md @@ -6,7 +6,7 @@ inference. The details of the transformation implemented in this package is described here [1]. This is done using the -[fake quantization op](https://www.tensorflow.org/versions/r0.12/api_docs/python/array_ops/fake_quantization). +[fake quantization op](https://www.tensorflow.org/api_guides/python/array_ops#Fake_quantization). 
Literature has shown that fixed point networks provide comparable performance to floating point networks [2]. This is achieved by modeling the quantization diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index 76f695dce0d1c4c104d823dcac9a4f94d3fba81d..55479bf5f74299bf09f131a6127f9f11d6192d90 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -475,7 +475,7 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay): def _IsValidUnfusedBatchNorm(graph, context): """Checks that the output of the unfused batch norm has consumers.""" add_shift = graph.get_operation_by_name( - context + '/BatchNorm/batchnorm/add_1') + context + '/BatchNorm/batchnorm_1/add_1') # Ensure that the output tensor of batch norm has consumers, otherwise this # is a dangling node and not a match. return bool(add_shift.outputs[0].consumers()) @@ -568,7 +568,7 @@ def _GetBatchNormParams(graph, context, has_scaling): op_suffix_mean = '/BatchNorm/moments/Squeeze' op_suffix_variance = '/BatchNorm/moments/Squeeze_1' - op_suffix_epsilon = '/BatchNorm/batchnorm/add/y' + op_suffix_epsilon = '/BatchNorm/batchnorm_1/add/y' op_suffix_bn_decay_mean = '/BatchNorm/AssignMovingAvg/decay' op_suffix_bn_decay_var = '/BatchNorm/AssignMovingAvg_1/decay' @@ -643,12 +643,12 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, Returns: A pair of Operations, the first is the original consumer node of the batch - norm (../BatchNorm/batchnorm/add_1), the second is the consumer node of + norm (../BatchNorm/batchnorm_1/add_1), the second is the consumer node of the folded graph (add_fold). 
""" mul_scale_name = 'mul_1' if has_scaling else 'mul' mul_scale = graph.get_operation_by_name(context + - '/BatchNorm/batchnorm/' + + '/BatchNorm/batchnorm_1/' + mul_scale_name) op_below = mul_scale.inputs[0].op weights = op_below.inputs[1] @@ -670,7 +670,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, ] scale_name = 'mul' if has_scaling else 'Rsqrt' scale = graph.get_operation_by_name( - context + '/BatchNorm/batchnorm/' + scale_name) + context + '/BatchNorm/batchnorm_1/' + scale_name) scale = array_ops.reshape(scale.outputs[0], new_shape, context + '/scale_reshape') @@ -698,7 +698,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, [(1, mul_fold.outputs[0])]) add_shift = graph.get_operation_by_name( - context + '/BatchNorm/batchnorm/add_1') + context + '/BatchNorm/batchnorm_1/add_1') corrected_output = conv_or_fc_folded.outputs[0] if correction_offset is not None: @@ -886,7 +886,7 @@ def _HasScaling(graph, input_to_ops_map, bn): Returns: A boolean indicating whether this batch norm layer has scaling enabled. 
""" - rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm/Rsqrt') + rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm_1/Rsqrt') rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op) return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1 diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py index fa5e11b4708402a4fe76a494ed59e30835ed1363..bfa9d3bf705e327091098a8e416b7902f852605a 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py @@ -516,13 +516,13 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): if has_scaling: if fused: return scope + '/BatchNorm_Fold/mul' - return scope + '/BatchNorm/batchnorm/mul' - return scope + '/BatchNorm/batchnorm/Rsqrt' + return scope + '/BatchNorm/batchnorm_1/mul' + return scope + '/BatchNorm/batchnorm_1/Rsqrt' def _BathNormBiasName(self, scope, fused): if fused: return scope + '/BatchNorm_Fold/bias' - return scope + '/BatchNorm/batchnorm/sub' + return scope + '/BatchNorm/batchnorm_1/sub' def _WeightInit(self, stddev): """Returns a truncated normal variable initializer. 
diff --git a/tensorflow/contrib/quantize/python/graph_matcher.py b/tensorflow/contrib/quantize/python/graph_matcher.py index bacc707a3abb5539b3b119c1ebc17bd7b30efc5b..aa3ca991c060b208ec71ae27e1ddc75df8a2c723 100644 --- a/tensorflow/contrib/quantize/python/graph_matcher.py +++ b/tensorflow/contrib/quantize/python/graph_matcher.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import abc +import itertools class Pattern(object): @@ -33,7 +34,7 @@ class Pattern(object): class OpTypePattern(Pattern): """A tree pattern that matches TF expressions with certain op types.""" - def __init__(self, op_type, name=None, inputs=None): + def __init__(self, op_type, name=None, inputs=None, ordered_inputs=True): """Initializes an OpTypePattern. Args: @@ -48,16 +49,25 @@ class OpTypePattern(Pattern): inputs: Optional list of `Pattern`s or strings that specify the patterns for the inputs of a matching op. If None, this pattern accepts any inputs of a matching op. + ordered_inputs: Defaults to True. If False, will match any op that + matches a permutation of the inputs. + + Raises: + ValueError: if too many inputs are provided when order_inputs is False. 
""" self._op_type = op_type self._name = name if inputs is None: inputs = [] + if len(inputs) > 8: + raise ValueError( + 'Only < 8 inputs are allowed when ordered_inputs is False.') self._inputs = [ input_pattern if isinstance(input_pattern, Pattern) else OpTypePattern(input_pattern) for input_pattern in inputs ] + self._ordered_inputs = ordered_inputs @property def name(self): @@ -78,12 +88,23 @@ class OpTypePattern(Pattern): if len(op.inputs) != len(self._inputs): return None - for input_tensor, input_pattern in zip(op.inputs, self._inputs): - input_match_result = input_pattern.match(input_tensor.op, input_tensor) - if input_match_result is None: - return None - match_result.merge_from(input_match_result) - return match_result + input_patterns_list = [self._inputs] + # If order doesn't matter for the inputs, then make sure we match at least + # one permutation of the inputs. + if not self._ordered_inputs: + input_patterns_list = list(itertools.permutations(self._inputs)) + + for input_patterns in input_patterns_list: + match_failed = False + for input_tensor, input_pattern in zip(op.inputs, input_patterns): + input_match_result = input_pattern.match(input_tensor.op, input_tensor) + if input_match_result is None: + match_failed = True + break + match_result.merge_from(input_match_result) + if not match_failed: + return match_result + return None class OneofPattern(Pattern): diff --git a/tensorflow/contrib/quantize/python/graph_matcher_test.py b/tensorflow/contrib/quantize/python/graph_matcher_test.py index 6d587572181c125faa02d36fb54933cff24f11c6..be741644b615416658001b385930dbe8429c82a2 100644 --- a/tensorflow/contrib/quantize/python/graph_matcher_test.py +++ b/tensorflow/contrib/quantize/python/graph_matcher_test.py @@ -22,6 +22,7 @@ from tensorflow.contrib.framework.python import ops as contrib_ops from tensorflow.contrib.layers.python.layers import initializers from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.quantize.python 
import graph_matcher +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util @@ -163,6 +164,44 @@ class GraphMatcherTest(test_util.TensorFlowTestCase): self.assertEqual(match_result.get_tensor('slice'), slicing) self.assertEqual(match_result.get_op('transpose'), transpose.op) + def test_ordered_pattern(self): + # + + + # / \ / \ + # x y and y x should both match when ordered inputs is False. + # Even when x and y are different operations. + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtypes.float32, shape=[], name='x') + y = constant_op.constant(1.0, dtype=dtypes.float32) + plus = x + y + + add_pattern_a = graph_matcher.OpTypePattern( + 'Add', inputs=['Const', 'Placeholder'], ordered_inputs=False) + add_pattern_b = graph_matcher.OpTypePattern( + 'Add', inputs=['Placeholder', 'Const'], ordered_inputs=False) + add_pattern_fail = graph_matcher.OpTypePattern( + 'Add', inputs=['Const', 'Placeholder'], ordered_inputs=True) + # Both add_pattern_a and add_pattern_b should match the graph since + # ordered_input was set False. + matcher_a = graph_matcher.GraphMatcher(add_pattern_a) + self.assertEqual([ + match_result.get_op(add_pattern_a) + for match_result in matcher_a.match_graph(g) + ], [plus.op]) + matcher_b = graph_matcher.GraphMatcher(add_pattern_b) + self.assertEqual([ + match_result.get_op(add_pattern_b) + for match_result in matcher_b.match_graph(g) + ], [plus.op]) + # But if ordered_inputs is True, the inputs list match should fail if not + # specified in the right order. 
+ matcher_fail = graph_matcher.GraphMatcher(add_pattern_fail) + self.assertEqual( + len([ + match_result.get_op(add_pattern_fail) + for match_result in matcher_fail.match_graph(g) + ]), 0) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py index 5c0e17dc8646ce7850e26ffaa80c0201cea456af..27069444a4bf8416b27787cb142ac9569ed99bb9 100644 --- a/tensorflow/contrib/quantize/python/quant_ops.py +++ b/tensorflow/contrib/quantize/python/quant_ops.py @@ -81,7 +81,8 @@ def LastValueQuantize(inputs, a tensor containing quantized values. """ with variable_scope.variable_scope( - None, default_name=name_prefix, values=[inputs], reuse=reuse): + None, default_name=name_prefix, values=[inputs], reuse=reuse) as scope: + scope.set_partitioner(None) input_shape = inputs.get_shape() input_dim = len(input_shape) if per_channel: @@ -189,7 +190,8 @@ def MovingAvgQuantize(inputs, a tensor containing quantized values. 
""" with variable_scope.variable_scope( - None, default_name=name_prefix, values=[inputs], reuse=reuse): + None, default_name=name_prefix, values=[inputs], reuse=reuse) as scope: + scope.set_partitioner(None) input_shape = inputs.get_shape() input_dim = len(input_shape) if per_channel: diff --git a/tensorflow/contrib/quantize/python/quant_ops_test.py b/tensorflow/contrib/quantize/python/quant_ops_test.py index 38846796028512a722752cd83b8bda3b5b0bb77f..c2a8def48012c808da18587c8ff462fa33a363c0 100644 --- a/tensorflow/contrib/quantize/python/quant_ops_test.py +++ b/tensorflow/contrib/quantize/python/quant_ops_test.py @@ -23,6 +23,8 @@ from tensorflow.python.client import session from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -73,6 +75,36 @@ class QuantOpsTest(googletest.TestCase): self.assertGreater(max_value, 0.0) self.assertLess(max_value, 1.0) + def testVariablesNotParitioned_LastValue(self): + # Variables added should not use a default partiioner since they are + # scalar. There would be a tensorflow error thrown if the partitioner was + # respected by the rewrite. + with ops.Graph().as_default(): + with variable_scope.variable_scope( + 'part', partitioner=partitioned_variables.fixed_size_partitioner(2)): + x = array_ops.placeholder(dtypes.float32, shape=[2]) + _ = quant_ops.LastValueQuantize( + x, + init_min=0.0, + init_max=0.0, + is_training=True, + vars_collection=_MIN_MAX_VARS) + + def testVariablesNotParitioned_MovingAvg(self): + # Variables added should not use a default partiioner since they are + # scalar. There would be a tensorflow error thrown if the partitioner was + # respected by the rewrite. 
+ with ops.Graph().as_default(): + with variable_scope.variable_scope( + 'part', partitioner=partitioned_variables.fixed_size_partitioner(2)): + x = array_ops.placeholder(dtypes.float32, shape=[2]) + _ = quant_ops.MovingAvgQuantize( + x, + init_min=0.0, + init_max=0.0, + is_training=True, + vars_collection=_MIN_MAX_VARS) + def _GetMinMaxValues(self, sess): min_max_vars = ops.get_collection(_MIN_MAX_VARS) self.assertEqual(len(min_max_vars), 2) diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 60616ea749cd3fb0edaf57cbe67484285ad41f75..cbba72643f7f166c473b6181edc292f695c4cbc2 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -218,8 +218,19 @@ def _FindLayersToQuantize(graph): """ input_pattern = graph_matcher.OpTypePattern('*') weight_var_pattern = graph_matcher.OpTypePattern('Variable|VariableV2') - weight_identity_pattern = graph_matcher.OpTypePattern( + weight_partition_identity_pattern = graph_matcher.OpTypePattern( 'Identity', inputs=[weight_var_pattern]) + weight_partition_concat_pattern = graph_matcher.OpTypePattern( + 'ConcatV2', inputs=[weight_partition_identity_pattern, '*', '*']) + weight_identity_pattern = graph_matcher.OpTypePattern( + 'Identity', + inputs=[ + graph_matcher.OneofPattern([ + weight_partition_identity_pattern, + weight_partition_concat_pattern, + weight_var_pattern, + ]) + ]) weight_resource_var_pattern = graph_matcher.OpTypePattern('ReadVariableOp') folded_weight_pattern = graph_matcher.OpTypePattern('Mul') @@ -233,37 +244,37 @@ def _FindLayersToQuantize(graph): weight_identity_pattern, weight_resource_var_pattern, folded_weight_pattern ]) - ]) + ], + ordered_inputs=False) folded_bias_mul_pattern = graph_matcher.OpTypePattern( - 'Mul', inputs=[graph_matcher.OpTypePattern('*'), layer_pattern]) + 'Mul', + inputs=[graph_matcher.OpTypePattern('*'), layer_pattern], + ordered_inputs=False) 
post_layer_op_correction_pattern = graph_matcher.OpTypePattern( - 'Add', inputs=[folded_bias_mul_pattern, - graph_matcher.OpTypePattern('*')]) + 'Add', + inputs=[folded_bias_mul_pattern, + graph_matcher.OpTypePattern('*')], + ordered_inputs=False) folded_bias_add_pattern = graph_matcher.OpTypePattern( 'Add', inputs=[ post_layer_op_correction_pattern, graph_matcher.OpTypePattern('*') - ]) + ], + ordered_inputs=False) bias_add_pattern = graph_matcher.OpTypePattern( - 'Add|BiasAdd', inputs=[layer_pattern, '*']) + 'Add|BiasAdd', inputs=[layer_pattern, '*'], ordered_inputs=False) # The bias can come from the bias add or the folded bias add. - bypass_pattern_a = graph_matcher.OpTypePattern( + bypass_pattern = graph_matcher.OpTypePattern( 'Add', inputs=[ graph_matcher.OneofPattern( [bias_add_pattern, folded_bias_add_pattern]), '*' - ]) - bypass_pattern_b = graph_matcher.OpTypePattern( - 'Add', - inputs=[ - '*', - graph_matcher.OneofPattern( - [bias_add_pattern, folded_bias_add_pattern]) - ]) + ], + ordered_inputs=False) # The input to the activation can come from bias add, fold bias add, the # bypasses. @@ -273,15 +284,14 @@ def _FindLayersToQuantize(graph): '|'.join(_ACTIVATION_TYPES) + '|Identity', inputs=[ graph_matcher.OneofPattern([ - bias_add_pattern, folded_bias_add_pattern, bypass_pattern_a, - bypass_pattern_b + bias_add_pattern, + folded_bias_add_pattern, + bypass_pattern, ]) ]) - post_activation_bypass_pattern_a = graph_matcher.OpTypePattern( - 'Add', inputs=['*', activation_pattern]) - post_activation_bypass_pattern_b = graph_matcher.OpTypePattern( - 'Add', inputs=[activation_pattern, '*']) + post_activation_bypass_pattern = graph_matcher.OpTypePattern( + 'Add', inputs=['*', activation_pattern], ordered_inputs=False) # The order of the following matching blocks is very important. 
Since matches # aren't guaranteed to be disjoint, we structure matches from largest to @@ -297,10 +307,7 @@ def _FindLayersToQuantize(graph): # to ensure we don't match only the first part of this layer, missing the # post activation bypass node. post_activation_bypass_layer_matcher = graph_matcher.GraphMatcher( - graph_matcher.OneofPattern([ - post_activation_bypass_pattern_a, - post_activation_bypass_pattern_b, - ])) + post_activation_bypass_pattern) for match_result in post_activation_bypass_layer_matcher.match_graph(graph): layer_op = match_result.get_op(layer_pattern) weight_tensor = match_result.get_tensor(weight_identity_pattern) @@ -312,14 +319,9 @@ def _FindLayersToQuantize(graph): bias_add_op = match_result.get_op(bias_add_pattern) if bias_add_op is None: bias_add_op = match_result.get_op(folded_bias_add_pattern) - bypass_op = match_result.get_op(bypass_pattern_a) - if bypass_op is None: - bypass_op = match_result.get_op(bypass_pattern_b) + bypass_op = match_result.get_op(bypass_pattern) post_activation_bypass_op = match_result.get_op( - post_activation_bypass_pattern_a) - if post_activation_bypass_op is None: - post_activation_bypass_op = match_result.get_op( - post_activation_bypass_pattern_b) + post_activation_bypass_pattern) if layer_op not in matched_layer_set: matched_layer_set.add(layer_op) layer_matches.append( @@ -340,9 +342,7 @@ def _FindLayersToQuantize(graph): bias_add_op = match_result.get_op(bias_add_pattern) if bias_add_op is None: bias_add_op = match_result.get_op(folded_bias_add_pattern) - bypass_op = match_result.get_op(bypass_pattern_a) - if bypass_op is None: - bypass_op = match_result.get_op(bypass_pattern_b) + bypass_op = match_result.get_op(bypass_pattern) if layer_op not in matched_layer_set: matched_layer_set.add(layer_op) layer_matches.append( diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py index 
e7360ae03ca535146dee007eeec88373adf39f12..92ca4a1b0c3126ebccf2b525f01f4d6455c4d527 100644 --- a/tensorflow/contrib/quantize/python/quantize_test.py +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -27,6 +27,8 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import googletest conv2d = layers.conv2d @@ -327,6 +329,66 @@ class QuantizeTest(test_util.TensorFlowTestCase): # No ops should be inserted or removed. self.assertEqual(op_names_before_quantize, op_names_after_quantize) + def testSinglePartitionedVariable(self): + self._RunTestOverParameters(self._testSinglePartitionedVariable) + + def _testSinglePartitionedVariable(self, is_training): + # When weights are partitioned into a single partition, the weights variable + # is followed by a identity -> identity (An additional identity node). + partitioner = partitioned_variables.fixed_size_partitioner(1) + graph = ops.Graph() + with graph.as_default(): + with variable_scope.variable_scope('part', partitioner=partitioner): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + input2 = array_ops.zeros((batch_size, height / 2, width / 2, 32)) + conv = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + scope='test/test') + node = math_ops.add(conv, input2, name='test/add') + node = nn_ops.relu6(node, name='test/relu6') + + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + # Check that the weight's quant node was added. 
+ op_names = [op.name for op in graph.get_operations()] + self.assertTrue( + 'part/test/test/weights_quant/FakeQuantWithMinMaxVars' in op_names) + + def testMultiplePartitionedVariables(self): + self._RunTestOverParameters(self._testMultiplePartitionedVariables) + + def _testMultiplePartitionedVariables(self, is_training): + # When weights are partitioned into multiple partitions the weights variable + # is followed by a identity -> concat -> identity to group the partitions. + partitioner = partitioned_variables.fixed_size_partitioner(2) + graph = ops.Graph() + with graph.as_default(): + with variable_scope.variable_scope('part', partitioner=partitioner): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + input2 = array_ops.zeros((batch_size, height / 2, width / 2, 32)) + conv = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + scope='test/test') + node = math_ops.add(conv, input2, name='test/add') + node = nn_ops.relu6(node, name='test/relu6') + + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + # Check that the weight's quant node was added. + op_names = [op.name for op in graph.get_operations()] + self.assertTrue( + 'part/test/test/weights_quant/FakeQuantWithMinMaxVars' in op_names) + def _WeightInit(self, stddev): """Returns truncated normal variable initializer. diff --git a/tensorflow/contrib/receptive_field/README.md b/tensorflow/contrib/receptive_field/README.md index 3ff85faf611afad71b6e6203453bbe97c56f9242..79b015a9163f5727caa40b54579c71e57621c92f 100644 --- a/tensorflow/contrib/receptive_field/README.md +++ b/tensorflow/contrib/receptive_field/README.md @@ -6,6 +6,32 @@ region your output features depend on. Better yet, using the parameters computed by the library, you can easily find the exact image region which is used to compute each convnet feature. 
+This library can be used to compute receptive field parameters of popular +convnets: + +

+
+convnet model | receptive field | effective stride | effective padding
+:-----------------: | :-------------: | :--------------: | :---------------:
+alexnet_v2 | 195 | 32 | 64
+vgg_16 | 212 | 32 | 90
+inception_v2 | 699 | 32 | 318
+inception_v3 | 1311 | 32 | 618
+inception_v4 | 2071 | 32 | 998
+inception_resnet_v2 | 3039 | 32 | 1482
+mobilenet_v1 | 315 | 32 | 126
+mobilenet_v1_075 | 315 | 32 | 126
+resnet_v1_50 | 483 | 32 | 241
+resnet_v1_101 | 1027 | 32 | 513
+resnet_v1_152 | 1507 | 32 | 753
+resnet_v1_200 | 1763 | 32 | 881
+
+
+ +A comprehensive table with pre-computed receptive field parameters for different +end-points, input resolutions, and other variants of these networks can be found +[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md). + ## Basic usage The main function to be called is `compute_receptive_field_from_graph_def`, @@ -96,9 +122,9 @@ The script will write to stdout the receptive field parameters for many variants of several popular convnets: AlexNet, VGG, ResNet, Inception, Mobilenet. They are also written to the file `/tmp/rf_benchmark_results.csv`. -TODO: include here a plot for receptive field sizes of different convnets. - -TODO: include table/link to pre-computed RF parameters. +A comprehensive table with pre-computed receptive field parameters for different +networks can be found +[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md). ## Compute RF parameters from a graph pbtxt diff --git a/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md b/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md new file mode 100644 index 0000000000000000000000000000000000000000..736fbef6e7c66176e74144115f0b1acd6bf6cd2f --- /dev/null +++ b/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md @@ -0,0 +1,629 @@ +# Pre-computed receptive field parameters + +## Table with results + +The table below presents the receptive field parameters for several popular +convolutional neural networks. These are computed using the models from the +[TF-Slim +repository](https://github.com/tensorflow/models/tree/master/research/slim), +by using the [rf_benchmark +script](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py). + +Questions? See the [FAQ](#faq). 
+
+CNN | resolution | end-point | RF | effective stride | effective padding
+:----------------------------: | :--------: | :------------------: | :--: | :--------------: | :---------------:
+alexnet_v2 | None | alexnet_v2/conv1 | 11 | 4 | 0
+alexnet_v2 | None | alexnet_v2/pool1 | 19 | 8 | 0
+alexnet_v2 | None | alexnet_v2/conv2 | 51 | 8 | 16
+alexnet_v2 | None | alexnet_v2/conv3 | 99 | 16 | 32
+alexnet_v2 | None | alexnet_v2/conv4 | 131 | 16 | 48
+alexnet_v2 | None | alexnet_v2/conv5 | 163 | 16 | 64
+alexnet_v2 | None | alexnet_v2/pool5 | 195 | 32 | 64
+alexnet_v2 | 224 | alexnet_v2/conv1 | 11 | 4 | 0
+alexnet_v2 | 224 | alexnet_v2/pool1 | 19 | 8 | 0
+alexnet_v2 | 224 | alexnet_v2/conv2 | 51 | 8 | 16
+alexnet_v2 | 224 | alexnet_v2/conv3 | 99 | 16 | 32
+alexnet_v2 | 224 | alexnet_v2/conv4 | 131 | 16 | 48
+alexnet_v2 | 224 | alexnet_v2/conv5 | 163 | 16 | 64
+alexnet_v2 | 224 | alexnet_v2/pool5 | 195 | 32 | 64
+alexnet_v2 | 321 | alexnet_v2/conv1 | 11 | 4 | 0
+alexnet_v2 | 321 | alexnet_v2/pool1 | 19 | 8 | 0
+alexnet_v2 | 321 | alexnet_v2/conv2 | 51 | 8 | 16
+alexnet_v2 | 321 | alexnet_v2/conv3 | 99 | 16 | 32
+alexnet_v2 | 321 | alexnet_v2/conv4 | 131 | 16 | 48
+alexnet_v2 | 321 | alexnet_v2/conv5 | 163 | 16 | 64
+alexnet_v2 | 321 | alexnet_v2/pool5 | 195 | 32 | 64
+vgg_a | None | vgg_a/conv1/conv1_1 | 3 | 1 | 1
+vgg_a | None | vgg_a/pool1 | 4 | 2 | 1
+vgg_a | None | vgg_a/conv2/conv2_1 | 8 | 2 | 3
+vgg_a | None | vgg_a/pool2 | 10 | 4 | 3
+vgg_a | None | vgg_a/conv3/conv3_1 | 18 | 4 | 7
+vgg_a | None | vgg_a/conv3/conv3_2 | 26 | 4 | 11
+vgg_a | None | vgg_a/pool3 | 30 | 8 | 11
+vgg_a | None | vgg_a/conv4/conv4_1 | 46 | 8 | 19
+vgg_a | None | vgg_a/conv4/conv4_2 | 62 | 8 | 27
+vgg_a | None | vgg_a/pool4 | 70 | 16 | 27
+vgg_a | None | vgg_a/conv5/conv5_1 | 102 | 16 | 43
+vgg_a | None | vgg_a/conv5/conv5_2 | 134 | 16 | 59
+vgg_a | None | vgg_a/pool5 | 150 | 32 | 59
+vgg_a | 224 | vgg_a/conv1/conv1_1 | 3 | 1 | 1
+vgg_a | 224 | vgg_a/pool1 | 4 | 2 | 1
+vgg_a | 224 | vgg_a/conv2/conv2_1 | 8 | 2 | 3
+vgg_a | 224 | vgg_a/pool2 | 10 | 4 | 3
+vgg_a | 224 | vgg_a/conv3/conv3_1 | 18 | 4 | 7
+vgg_a | 224 | vgg_a/conv3/conv3_2 | 26 | 4 | 11
+vgg_a | 224 | vgg_a/pool3 | 30 | 8 | 11
+vgg_a | 224 | vgg_a/conv4/conv4_1 | 46 | 8 | 19
+vgg_a | 224 | vgg_a/conv4/conv4_2 | 62 | 8 | 27
+vgg_a | 224 | vgg_a/pool4 | 70 | 16 | 27
+vgg_a | 224 | vgg_a/conv5/conv5_1 | 102 | 16 | 43
+vgg_a | 224 | vgg_a/conv5/conv5_2 | 134 | 16 | 59
+vgg_a | 224 | vgg_a/pool5 | 150 | 32 | 59
+vgg_a | 321 | vgg_a/conv1/conv1_1 | 3 | 1 | 1
+vgg_a | 321 | vgg_a/pool1 | 4 | 2 | 1
+vgg_a | 321 | vgg_a/conv2/conv2_1 | 8 | 2 | 3
+vgg_a | 321 | vgg_a/pool2 | 10 | 4 | 3
+vgg_a | 321 | vgg_a/conv3/conv3_1 | 18 | 4 | 7
+vgg_a | 321 | vgg_a/conv3/conv3_2 | 26 | 4 | 11
+vgg_a | 321 | vgg_a/pool3 | 30 | 8 | 11
+vgg_a | 321 | vgg_a/conv4/conv4_1 | 46 | 8 | 19
+vgg_a | 321 | vgg_a/conv4/conv4_2 | 62 | 8 | 27
+vgg_a | 321 | vgg_a/pool4 | 70 | 16 | 27
+vgg_a | 321 | vgg_a/conv5/conv5_1 | 102 | 16 | 43
+vgg_a | 321 | vgg_a/conv5/conv5_2 | 134 | 16 | 59
+vgg_a | 321 | vgg_a/pool5 | 150 | 32 | 59
+vgg_16 | None | vgg_16/conv1/conv1_1 | 3 | 1 | 1
+vgg_16 | None | vgg_16/pool1 | 6 | 2 | 2
+vgg_16 | None | vgg_16/conv2/conv2_1 | 10 | 2 | 4
+vgg_16 | None | vgg_16/pool2 | 16 | 4 | 6
+vgg_16 | None | vgg_16/conv3/conv3_1 | 24 | 4 | 10
+vgg_16 | None | vgg_16/conv3/conv3_2 | 32 | 4 | 14
+vgg_16 | None | vgg_16/pool3 | 44 | 8 | 18
+vgg_16 | None | vgg_16/conv4/conv4_1 | 60 | 8 | 26
+vgg_16 | None | vgg_16/conv4/conv4_2 | 76 | 8 | 34
+vgg_16 | None | vgg_16/pool4 | 100 | 16 | 42
+vgg_16 | None | vgg_16/conv5/conv5_1 | 132 | 16 | 58
+vgg_16 | None | vgg_16/conv5/conv5_2 | 164 | 16 | 74
+vgg_16 | None | vgg_16/pool5 | 212 | 32 | 90
+vgg_16 | 224 | vgg_16/conv1/conv1_1 | 3 | 1 | 1
+vgg_16 | 224 | vgg_16/pool1 | 6 | 2 | 2
+vgg_16 | 224 | vgg_16/conv2/conv2_1 | 10 | 2 | 4
+vgg_16 | 224 | vgg_16/pool2 | 16 | 4 | 6
+vgg_16 | 224 | vgg_16/conv3/conv3_1 | 24 | 4 | 10
+vgg_16 | 224 | vgg_16/conv3/conv3_2
| 32 | 4 | 14 +vgg_16 | 224 | vgg_16/pool3 | 44 | 8 | 18 +vgg_16 | 224 | vgg_16/conv4/conv4_1 | 60 | 8 | 26 +vgg_16 | 224 | vgg_16/conv4/conv4_2 | 76 | 8 | 34 +vgg_16 | 224 | vgg_16/pool4 | 100 | 16 | 42 +vgg_16 | 224 | vgg_16/conv5/conv5_1 | 132 | 16 | 58 +vgg_16 | 224 | vgg_16/conv5/conv5_2 | 164 | 16 | 74 +vgg_16 | 224 | vgg_16/pool5 | 212 | 32 | 90 +vgg_16 | 321 | vgg_16/conv1/conv1_1 | 3 | 1 | 1 +vgg_16 | 321 | vgg_16/pool1 | 6 | 2 | 2 +vgg_16 | 321 | vgg_16/conv2/conv2_1 | 10 | 2 | 4 +vgg_16 | 321 | vgg_16/pool2 | 16 | 4 | 6 +vgg_16 | 321 | vgg_16/conv3/conv3_1 | 24 | 4 | 10 +vgg_16 | 321 | vgg_16/conv3/conv3_2 | 32 | 4 | 14 +vgg_16 | 321 | vgg_16/pool3 | 44 | 8 | 18 +vgg_16 | 321 | vgg_16/conv4/conv4_1 | 60 | 8 | 26 +vgg_16 | 321 | vgg_16/conv4/conv4_2 | 76 | 8 | 34 +vgg_16 | 321 | vgg_16/pool4 | 100 | 16 | 42 +vgg_16 | 321 | vgg_16/conv5/conv5_1 | 132 | 16 | 58 +vgg_16 | 321 | vgg_16/conv5/conv5_2 | 164 | 16 | 74 +vgg_16 | 321 | vgg_16/pool5 | 212 | 32 | 90 +inception_v2 | None | Conv2d_1a_7x7 | 7 | 2 | None +inception_v2 | None | MaxPool_2a_3x3 | 11 | 4 | None +inception_v2 | None | Conv2d_2b_1x1 | 11 | 4 | None +inception_v2 | None | Conv2d_2c_3x3 | 19 | 4 | None +inception_v2 | None | MaxPool_3a_3x3 | 27 | 8 | None +inception_v2 | None | Mixed_3b | 59 | 8 | None +inception_v2 | None | Mixed_3c | 91 | 8 | None +inception_v2 | None | Mixed_4a | 123 | 16 | None +inception_v2 | None | Mixed_4b | 187 | 16 | None +inception_v2 | None | Mixed_4c | 251 | 16 | None +inception_v2 | None | Mixed_4d | 315 | 16 | None +inception_v2 | None | Mixed_4e | 379 | 16 | None +inception_v2 | None | Mixed_5a | 443 | 32 | None +inception_v2 | None | Mixed_5b | 571 | 32 | None +inception_v2 | None | Mixed_5c | 699 | 32 | None +inception_v2 | 224 | Conv2d_1a_7x7 | 7 | 2 | 2 +inception_v2 | 224 | MaxPool_2a_3x3 | 11 | 4 | 2 +inception_v2 | 224 | Conv2d_2b_1x1 | 11 | 4 | 2 +inception_v2 | 224 | Conv2d_2c_3x3 | 19 | 4 | 6 +inception_v2 | 224 | MaxPool_3a_3x3 | 27 | 8 | 6 
+inception_v2 | 224 | Mixed_3b | 59 | 8 | 22 +inception_v2 | 224 | Mixed_3c | 91 | 8 | 38 +inception_v2 | 224 | Mixed_4a | 123 | 16 | 46 +inception_v2 | 224 | Mixed_4b | 187 | 16 | 78 +inception_v2 | 224 | Mixed_4c | 251 | 16 | 110 +inception_v2 | 224 | Mixed_4d | 315 | 16 | 142 +inception_v2 | 224 | Mixed_4e | 379 | 16 | 174 +inception_v2 | 224 | Mixed_5a | 443 | 32 | 190 +inception_v2 | 224 | Mixed_5b | 571 | 32 | 254 +inception_v2 | 224 | Mixed_5c | 699 | 32 | 318 +inception_v2 | 321 | Conv2d_1a_7x7 | 7 | 2 | 3 +inception_v2 | 321 | MaxPool_2a_3x3 | 11 | 4 | 5 +inception_v2 | 321 | Conv2d_2b_1x1 | 11 | 4 | 5 +inception_v2 | 321 | Conv2d_2c_3x3 | 19 | 4 | 9 +inception_v2 | 321 | MaxPool_3a_3x3 | 27 | 8 | 13 +inception_v2 | 321 | Mixed_3b | 59 | 8 | 29 +inception_v2 | 321 | Mixed_3c | 91 | 8 | 45 +inception_v2 | 321 | Mixed_4a | 123 | 16 | 61 +inception_v2 | 321 | Mixed_4b | 187 | 16 | 93 +inception_v2 | 321 | Mixed_4c | 251 | 16 | 125 +inception_v2 | 321 | Mixed_4d | 315 | 16 | 157 +inception_v2 | 321 | Mixed_4e | 379 | 16 | 189 +inception_v2 | 321 | Mixed_5a | 443 | 32 | 221 +inception_v2 | 321 | Mixed_5b | 571 | 32 | 285 +inception_v2 | 321 | Mixed_5c | 699 | 32 | 349 +inception_v2-no-separable-conv | None | Conv2d_1a_7x7 | 7 | 2 | None +inception_v2-no-separable-conv | None | MaxPool_2a_3x3 | 11 | 4 | None +inception_v2-no-separable-conv | None | Conv2d_2b_1x1 | 11 | 4 | None +inception_v2-no-separable-conv | None | Conv2d_2c_3x3 | 19 | 4 | None +inception_v2-no-separable-conv | None | MaxPool_3a_3x3 | 27 | 8 | None +inception_v2-no-separable-conv | None | Mixed_3b | 59 | 8 | None +inception_v2-no-separable-conv | None | Mixed_3c | 91 | 8 | None +inception_v2-no-separable-conv | None | Mixed_4a | 123 | 16 | None +inception_v2-no-separable-conv | None | Mixed_4b | 187 | 16 | None +inception_v2-no-separable-conv | None | Mixed_4c | 251 | 16 | None +inception_v2-no-separable-conv | None | Mixed_4d | 315 | 16 | None +inception_v2-no-separable-conv | None | 
Mixed_4e | 379 | 16 | None +inception_v2-no-separable-conv | None | Mixed_5a | 443 | 32 | None +inception_v2-no-separable-conv | None | Mixed_5b | 571 | 32 | None +inception_v2-no-separable-conv | None | Mixed_5c | 699 | 32 | None +inception_v2-no-separable-conv | 224 | Conv2d_1a_7x7 | 7 | 2 | 2 +inception_v2-no-separable-conv | 224 | MaxPool_2a_3x3 | 11 | 4 | 2 +inception_v2-no-separable-conv | 224 | Conv2d_2b_1x1 | 11 | 4 | 2 +inception_v2-no-separable-conv | 224 | Conv2d_2c_3x3 | 19 | 4 | 6 +inception_v2-no-separable-conv | 224 | MaxPool_3a_3x3 | 27 | 8 | 6 +inception_v2-no-separable-conv | 224 | Mixed_3b | 59 | 8 | 22 +inception_v2-no-separable-conv | 224 | Mixed_3c | 91 | 8 | 38 +inception_v2-no-separable-conv | 224 | Mixed_4a | 123 | 16 | 46 +inception_v2-no-separable-conv | 224 | Mixed_4b | 187 | 16 | 78 +inception_v2-no-separable-conv | 224 | Mixed_4c | 251 | 16 | 110 +inception_v2-no-separable-conv | 224 | Mixed_4d | 315 | 16 | 142 +inception_v2-no-separable-conv | 224 | Mixed_4e | 379 | 16 | 174 +inception_v2-no-separable-conv | 224 | Mixed_5a | 443 | 32 | 190 +inception_v2-no-separable-conv | 224 | Mixed_5b | 571 | 32 | 254 +inception_v2-no-separable-conv | 224 | Mixed_5c | 699 | 32 | 318 +inception_v2-no-separable-conv | 321 | Conv2d_1a_7x7 | 7 | 2 | 3 +inception_v2-no-separable-conv | 321 | MaxPool_2a_3x3 | 11 | 4 | 5 +inception_v2-no-separable-conv | 321 | Conv2d_2b_1x1 | 11 | 4 | 5 +inception_v2-no-separable-conv | 321 | Conv2d_2c_3x3 | 19 | 4 | 9 +inception_v2-no-separable-conv | 321 | MaxPool_3a_3x3 | 27 | 8 | 13 +inception_v2-no-separable-conv | 321 | Mixed_3b | 59 | 8 | 29 +inception_v2-no-separable-conv | 321 | Mixed_3c | 91 | 8 | 45 +inception_v2-no-separable-conv | 321 | Mixed_4a | 123 | 16 | 61 +inception_v2-no-separable-conv | 321 | Mixed_4b | 187 | 16 | 93 +inception_v2-no-separable-conv | 321 | Mixed_4c | 251 | 16 | 125 +inception_v2-no-separable-conv | 321 | Mixed_4d | 315 | 16 | 157 +inception_v2-no-separable-conv | 321 | Mixed_4e | 379 
| 16 | 189 +inception_v2-no-separable-conv | 321 | Mixed_5a | 443 | 32 | 221 +inception_v2-no-separable-conv | 321 | Mixed_5b | 571 | 32 | 285 +inception_v2-no-separable-conv | 321 | Mixed_5c | 699 | 32 | 349 +inception_v3 | None | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v3 | None | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v3 | None | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v3 | None | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_v3 | None | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_v3 | None | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_v3 | None | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_v3 | None | Mixed_5b | 63 | 8 | 18 +inception_v3 | None | Mixed_5c | 95 | 8 | 34 +inception_v3 | None | Mixed_5d | 127 | 8 | 50 +inception_v3 | None | Mixed_6a | 159 | 16 | 58 +inception_v3 | None | Mixed_6b | 351 | 16 | 154 +inception_v3 | None | Mixed_6c | 543 | 16 | 250 +inception_v3 | None | Mixed_6d | 735 | 16 | 346 +inception_v3 | None | Mixed_6e | 927 | 16 | 442 +inception_v3 | None | Mixed_7a | 1055 | 32 | 490 +inception_v3 | None | Mixed_7b | 1183 | 32 | 554 +inception_v3 | None | Mixed_7c | 1311 | 32 | 618 +inception_v3 | 224 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v3 | 224 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v3 | 224 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v3 | 224 | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_v3 | 224 | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_v3 | 224 | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_v3 | 224 | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_v3 | 224 | Mixed_5b | 63 | 8 | 18 +inception_v3 | 224 | Mixed_5c | 95 | 8 | 34 +inception_v3 | 224 | Mixed_5d | 127 | 8 | 50 +inception_v3 | 224 | Mixed_6a | 159 | 16 | 58 +inception_v3 | 224 | Mixed_6b | 351 | 16 | 154 +inception_v3 | 224 | Mixed_6c | 543 | 16 | 250 +inception_v3 | 224 | Mixed_6d | 735 | 16 | 346 +inception_v3 | 224 | Mixed_6e | 927 | 16 | 442 +inception_v3 | 224 | Mixed_7a | 1055 | 32 | 490 +inception_v3 | 224 | Mixed_7b | 1183 | 32 | 554 +inception_v3 | 224 | Mixed_7c | 1311 | 32 | 618 +inception_v3 | 321 | Conv2d_1a_3x3 | 3 | 
2 | 0 +inception_v3 | 321 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v3 | 321 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v3 | 321 | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_v3 | 321 | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_v3 | 321 | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_v3 | 321 | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_v3 | 321 | Mixed_5b | 63 | 8 | 18 +inception_v3 | 321 | Mixed_5c | 95 | 8 | 34 +inception_v3 | 321 | Mixed_5d | 127 | 8 | 50 +inception_v3 | 321 | Mixed_6a | 159 | 16 | 58 +inception_v3 | 321 | Mixed_6b | 351 | 16 | 154 +inception_v3 | 321 | Mixed_6c | 543 | 16 | 250 +inception_v3 | 321 | Mixed_6d | 735 | 16 | 346 +inception_v3 | 321 | Mixed_6e | 927 | 16 | 442 +inception_v3 | 321 | Mixed_7a | 1055 | 32 | 490 +inception_v3 | 321 | Mixed_7b | 1183 | 32 | 554 +inception_v3 | 321 | Mixed_7c | 1311 | 32 | 618 +inception_v4 | None | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v4 | None | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v4 | None | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v4 | None | Mixed_3a | 15 | 4 | 2 +inception_v4 | None | Mixed_4a | 47 | 4 | 14 +inception_v4 | None | Mixed_5a | 55 | 8 | 14 +inception_v4 | None | Mixed_5b | 87 | 8 | 30 +inception_v4 | None | Mixed_5c | 119 | 8 | 46 +inception_v4 | None | Mixed_5d | 151 | 8 | 62 +inception_v4 | None | Mixed_5e | 183 | 8 | 78 +inception_v4 | None | Mixed_6a | 215 | 16 | 86 +inception_v4 | None | Mixed_6b | 407 | 16 | 182 +inception_v4 | None | Mixed_6c | 599 | 16 | 278 +inception_v4 | None | Mixed_6d | 791 | 16 | 374 +inception_v4 | None | Mixed_6e | 983 | 16 | 470 +inception_v4 | None | Mixed_6f | 1175 | 16 | 566 +inception_v4 | None | Mixed_6g | 1367 | 16 | 662 +inception_v4 | None | Mixed_6h | 1559 | 16 | 758 +inception_v4 | None | Mixed_7a | 1687 | 32 | 806 +inception_v4 | None | Mixed_7b | 1815 | 32 | 870 +inception_v4 | None | Mixed_7c | 1943 | 32 | 934 +inception_v4 | None | Mixed_7d | 2071 | 32 | 998 +inception_v4 | 224 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v4 | 224 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v4 
| 224 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v4 | 224 | Mixed_3a | 15 | 4 | 2 +inception_v4 | 224 | Mixed_4a | 47 | 4 | 14 +inception_v4 | 224 | Mixed_5a | 55 | 8 | 14 +inception_v4 | 224 | Mixed_5b | 87 | 8 | 30 +inception_v4 | 224 | Mixed_5c | 119 | 8 | 46 +inception_v4 | 224 | Mixed_5d | 151 | 8 | 62 +inception_v4 | 224 | Mixed_5e | 183 | 8 | 78 +inception_v4 | 224 | Mixed_6a | 215 | 16 | 86 +inception_v4 | 224 | Mixed_6b | 407 | 16 | 182 +inception_v4 | 224 | Mixed_6c | 599 | 16 | 278 +inception_v4 | 224 | Mixed_6d | 791 | 16 | 374 +inception_v4 | 224 | Mixed_6e | 983 | 16 | 470 +inception_v4 | 224 | Mixed_6f | 1175 | 16 | 566 +inception_v4 | 224 | Mixed_6g | 1367 | 16 | 662 +inception_v4 | 224 | Mixed_6h | 1559 | 16 | 758 +inception_v4 | 224 | Mixed_7a | 1687 | 32 | 806 +inception_v4 | 224 | Mixed_7b | 1815 | 32 | 870 +inception_v4 | 224 | Mixed_7c | 1943 | 32 | 934 +inception_v4 | 224 | Mixed_7d | 2071 | 32 | 998 +inception_v4 | 321 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v4 | 321 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v4 | 321 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v4 | 321 | Mixed_3a | 15 | 4 | 2 +inception_v4 | 321 | Mixed_4a | 47 | 4 | 14 +inception_v4 | 321 | Mixed_5a | 55 | 8 | 14 +inception_v4 | 321 | Mixed_5b | 87 | 8 | 30 +inception_v4 | 321 | Mixed_5c | 119 | 8 | 46 +inception_v4 | 321 | Mixed_5d | 151 | 8 | 62 +inception_v4 | 321 | Mixed_5e | 183 | 8 | 78 +inception_v4 | 321 | Mixed_6a | 215 | 16 | 86 +inception_v4 | 321 | Mixed_6b | 407 | 16 | 182 +inception_v4 | 321 | Mixed_6c | 599 | 16 | 278 +inception_v4 | 321 | Mixed_6d | 791 | 16 | 374 +inception_v4 | 321 | Mixed_6e | 983 | 16 | 470 +inception_v4 | 321 | Mixed_6f | 1175 | 16 | 566 +inception_v4 | 321 | Mixed_6g | 1367 | 16 | 662 +inception_v4 | 321 | Mixed_6h | 1559 | 16 | 758 +inception_v4 | 321 | Mixed_7a | 1687 | 32 | 806 +inception_v4 | 321 | Mixed_7b | 1815 | 32 | 870 +inception_v4 | 321 | Mixed_7c | 1943 | 32 | 934 +inception_v4 | 321 | Mixed_7d | 2071 | 32 | 998 +inception_resnet_v2 
| None | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_resnet_v2 | None | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_resnet_v2 | None | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_resnet_v2 | None | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_resnet_v2 | None | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_resnet_v2 | None | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_resnet_v2 | None | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_resnet_v2 | None | Mixed_5b | 63 | 8 | 18 +inception_resnet_v2 | None | Mixed_6a | 415 | 16 | 186 +inception_resnet_v2 | None | PreAuxLogits | 2335 | 16 | 1146 +inception_resnet_v2 | None | Mixed_7a | 2399 | 32 | 1162 +inception_resnet_v2 | None | Conv2d_7b_1x1 | 3039 | 32 | 1482 +inception_resnet_v2 | 224 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_resnet_v2 | 224 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_resnet_v2 | 224 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_resnet_v2 | 224 | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_resnet_v2 | 224 | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_resnet_v2 | 224 | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_resnet_v2 | 224 | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_resnet_v2 | 224 | Mixed_5b | 63 | 8 | 18 +inception_resnet_v2 | 224 | Mixed_6a | 415 | 16 | 186 +inception_resnet_v2 | 224 | PreAuxLogits | 2335 | 16 | 1146 +inception_resnet_v2 | 224 | Mixed_7a | 2399 | 32 | 1162 +inception_resnet_v2 | 224 | Conv2d_7b_1x1 | 3039 | 32 | 1482 +inception_resnet_v2 | 321 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_resnet_v2 | 321 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_resnet_v2 | 321 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_resnet_v2 | 321 | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_resnet_v2 | 321 | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_resnet_v2 | 321 | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_resnet_v2 | 321 | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_resnet_v2 | 321 | Mixed_5b | 63 | 8 | 18 +inception_resnet_v2 | 321 | Mixed_6a | 415 | 16 | 186 +inception_resnet_v2 | 321 | PreAuxLogits | 2335 | 16 | 1146 +inception_resnet_v2 | 321 | Mixed_7a | 2399 | 32 | 1162 +inception_resnet_v2 | 321 | 
Conv2d_7b_1x1 | 3039 | 32 | 1482 +inception_resnet_v2-same | None | Conv2d_1a_3x3 | 3 | 2 | None +inception_resnet_v2-same | None | Conv2d_2a_3x3 | 7 | 2 | None +inception_resnet_v2-same | None | Conv2d_2b_3x3 | 11 | 2 | None +inception_resnet_v2-same | None | MaxPool_3a_3x3 | 15 | 4 | None +inception_resnet_v2-same | None | Conv2d_3b_1x1 | 15 | 4 | None +inception_resnet_v2-same | None | Conv2d_4a_3x3 | 23 | 4 | None +inception_resnet_v2-same | None | MaxPool_5a_3x3 | 31 | 8 | None +inception_resnet_v2-same | None | Mixed_5b | 63 | 8 | None +inception_resnet_v2-same | None | Mixed_6a | 415 | 16 | None +inception_resnet_v2-same | None | PreAuxLogits | 2335 | 16 | None +inception_resnet_v2-same | None | Mixed_7a | 2399 | 32 | None +inception_resnet_v2-same | None | Conv2d_7b_1x1 | 3039 | 32 | None +inception_resnet_v2-same | 224 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_resnet_v2-same | 224 | Conv2d_2a_3x3 | 7 | 2 | 2 +inception_resnet_v2-same | 224 | Conv2d_2b_3x3 | 11 | 2 | 4 +inception_resnet_v2-same | 224 | MaxPool_3a_3x3 | 15 | 4 | 4 +inception_resnet_v2-same | 224 | Conv2d_3b_1x1 | 15 | 4 | 4 +inception_resnet_v2-same | 224 | Conv2d_4a_3x3 | 23 | 4 | 8 +inception_resnet_v2-same | 224 | MaxPool_5a_3x3 | 31 | 8 | 8 +inception_resnet_v2-same | 224 | Mixed_5b | 63 | 8 | 24 +inception_resnet_v2-same | 224 | Mixed_6a | 415 | 16 | 192 +inception_resnet_v2-same | 224 | PreAuxLogits | 2335 | 16 | 1152 +inception_resnet_v2-same | 224 | Mixed_7a | 2399 | 32 | 1168 +inception_resnet_v2-same | 224 | Conv2d_7b_1x1 | 3039 | 32 | 1488 +inception_resnet_v2-same | 321 | Conv2d_1a_3x3 | 3 | 2 | 1 +inception_resnet_v2-same | 321 | Conv2d_2a_3x3 | 7 | 2 | 3 +inception_resnet_v2-same | 321 | Conv2d_2b_3x3 | 11 | 2 | 5 +inception_resnet_v2-same | 321 | MaxPool_3a_3x3 | 15 | 4 | 7 +inception_resnet_v2-same | 321 | Conv2d_3b_1x1 | 15 | 4 | 7 +inception_resnet_v2-same | 321 | Conv2d_4a_3x3 | 23 | 4 | 11 +inception_resnet_v2-same | 321 | MaxPool_5a_3x3 | 31 | 8 | 15 
+inception_resnet_v2-same | 321 | Mixed_5b | 63 | 8 | 31 +inception_resnet_v2-same | 321 | Mixed_6a | 415 | 16 | 207 +inception_resnet_v2-same | 321 | PreAuxLogits | 2335 | 16 | 1167 +inception_resnet_v2-same | 321 | Mixed_7a | 2399 | 32 | 1199 +inception_resnet_v2-same | 321 | Conv2d_7b_1x1 | 3039 | 32 | 1519 +mobilenet_v1 | None | Conv2d_0 | 3 | 2 | None +mobilenet_v1 | None | Conv2d_1_pointwise | 7 | 2 | None +mobilenet_v1 | None | Conv2d_2_pointwise | 11 | 4 | None +mobilenet_v1 | None | Conv2d_3_pointwise | 19 | 4 | None +mobilenet_v1 | None | Conv2d_4_pointwise | 27 | 8 | None +mobilenet_v1 | None | Conv2d_5_pointwise | 43 | 8 | None +mobilenet_v1 | None | Conv2d_6_pointwise | 59 | 16 | None +mobilenet_v1 | None | Conv2d_7_pointwise | 91 | 16 | None +mobilenet_v1 | None | Conv2d_8_pointwise | 123 | 16 | None +mobilenet_v1 | None | Conv2d_9_pointwise | 155 | 16 | None +mobilenet_v1 | None | Conv2d_10_pointwise | 187 | 16 | None +mobilenet_v1 | None | Conv2d_11_pointwise | 219 | 16 | None +mobilenet_v1 | None | Conv2d_12_pointwise | 251 | 32 | None +mobilenet_v1 | None | Conv2d_13_pointwise | 315 | 32 | None +mobilenet_v1 | 224 | Conv2d_0 | 3 | 2 | 0 +mobilenet_v1 | 224 | Conv2d_1_pointwise | 7 | 2 | 2 +mobilenet_v1 | 224 | Conv2d_2_pointwise | 11 | 4 | 2 +mobilenet_v1 | 224 | Conv2d_3_pointwise | 19 | 4 | 6 +mobilenet_v1 | 224 | Conv2d_4_pointwise | 27 | 8 | 6 +mobilenet_v1 | 224 | Conv2d_5_pointwise | 43 | 8 | 14 +mobilenet_v1 | 224 | Conv2d_6_pointwise | 59 | 16 | 14 +mobilenet_v1 | 224 | Conv2d_7_pointwise | 91 | 16 | 30 +mobilenet_v1 | 224 | Conv2d_8_pointwise | 123 | 16 | 46 +mobilenet_v1 | 224 | Conv2d_9_pointwise | 155 | 16 | 62 +mobilenet_v1 | 224 | Conv2d_10_pointwise | 187 | 16 | 78 +mobilenet_v1 | 224 | Conv2d_11_pointwise | 219 | 16 | 94 +mobilenet_v1 | 224 | Conv2d_12_pointwise | 251 | 32 | 94 +mobilenet_v1 | 224 | Conv2d_13_pointwise | 315 | 32 | 126 +mobilenet_v1 | 321 | Conv2d_0 | 3 | 2 | 1 +mobilenet_v1 | 321 | Conv2d_1_pointwise | 7 | 2 | 3 
+mobilenet_v1 | 321 | Conv2d_2_pointwise | 11 | 4 | 5 +mobilenet_v1 | 321 | Conv2d_3_pointwise | 19 | 4 | 9 +mobilenet_v1 | 321 | Conv2d_4_pointwise | 27 | 8 | 13 +mobilenet_v1 | 321 | Conv2d_5_pointwise | 43 | 8 | 21 +mobilenet_v1 | 321 | Conv2d_6_pointwise | 59 | 16 | 29 +mobilenet_v1 | 321 | Conv2d_7_pointwise | 91 | 16 | 45 +mobilenet_v1 | 321 | Conv2d_8_pointwise | 123 | 16 | 61 +mobilenet_v1 | 321 | Conv2d_9_pointwise | 155 | 16 | 77 +mobilenet_v1 | 321 | Conv2d_10_pointwise | 187 | 16 | 93 +mobilenet_v1 | 321 | Conv2d_11_pointwise | 219 | 16 | 109 +mobilenet_v1 | 321 | Conv2d_12_pointwise | 251 | 32 | 125 +mobilenet_v1 | 321 | Conv2d_13_pointwise | 315 | 32 | 157 +mobilenet_v1_075 | None | Conv2d_0 | 3 | 2 | None +mobilenet_v1_075 | None | Conv2d_1_pointwise | 7 | 2 | None +mobilenet_v1_075 | None | Conv2d_2_pointwise | 11 | 4 | None +mobilenet_v1_075 | None | Conv2d_3_pointwise | 19 | 4 | None +mobilenet_v1_075 | None | Conv2d_4_pointwise | 27 | 8 | None +mobilenet_v1_075 | None | Conv2d_5_pointwise | 43 | 8 | None +mobilenet_v1_075 | None | Conv2d_6_pointwise | 59 | 16 | None +mobilenet_v1_075 | None | Conv2d_7_pointwise | 91 | 16 | None +mobilenet_v1_075 | None | Conv2d_8_pointwise | 123 | 16 | None +mobilenet_v1_075 | None | Conv2d_9_pointwise | 155 | 16 | None +mobilenet_v1_075 | None | Conv2d_10_pointwise | 187 | 16 | None +mobilenet_v1_075 | None | Conv2d_11_pointwise | 219 | 16 | None +mobilenet_v1_075 | None | Conv2d_12_pointwise | 251 | 32 | None +mobilenet_v1_075 | None | Conv2d_13_pointwise | 315 | 32 | None +mobilenet_v1_075 | 224 | Conv2d_0 | 3 | 2 | 0 +mobilenet_v1_075 | 224 | Conv2d_1_pointwise | 7 | 2 | 2 +mobilenet_v1_075 | 224 | Conv2d_2_pointwise | 11 | 4 | 2 +mobilenet_v1_075 | 224 | Conv2d_3_pointwise | 19 | 4 | 6 +mobilenet_v1_075 | 224 | Conv2d_4_pointwise | 27 | 8 | 6 +mobilenet_v1_075 | 224 | Conv2d_5_pointwise | 43 | 8 | 14 +mobilenet_v1_075 | 224 | Conv2d_6_pointwise | 59 | 16 | 14 +mobilenet_v1_075 | 224 | Conv2d_7_pointwise | 91 
| 16 | 30 +mobilenet_v1_075 | 224 | Conv2d_8_pointwise | 123 | 16 | 46 +mobilenet_v1_075 | 224 | Conv2d_9_pointwise | 155 | 16 | 62 +mobilenet_v1_075 | 224 | Conv2d_10_pointwise | 187 | 16 | 78 +mobilenet_v1_075 | 224 | Conv2d_11_pointwise | 219 | 16 | 94 +mobilenet_v1_075 | 224 | Conv2d_12_pointwise | 251 | 32 | 94 +mobilenet_v1_075 | 224 | Conv2d_13_pointwise | 315 | 32 | 126 +mobilenet_v1_075 | 321 | Conv2d_0 | 3 | 2 | 1 +mobilenet_v1_075 | 321 | Conv2d_1_pointwise | 7 | 2 | 3 +mobilenet_v1_075 | 321 | Conv2d_2_pointwise | 11 | 4 | 5 +mobilenet_v1_075 | 321 | Conv2d_3_pointwise | 19 | 4 | 9 +mobilenet_v1_075 | 321 | Conv2d_4_pointwise | 27 | 8 | 13 +mobilenet_v1_075 | 321 | Conv2d_5_pointwise | 43 | 8 | 21 +mobilenet_v1_075 | 321 | Conv2d_6_pointwise | 59 | 16 | 29 +mobilenet_v1_075 | 321 | Conv2d_7_pointwise | 91 | 16 | 45 +mobilenet_v1_075 | 321 | Conv2d_8_pointwise | 123 | 16 | 61 +mobilenet_v1_075 | 321 | Conv2d_9_pointwise | 155 | 16 | 77 +mobilenet_v1_075 | 321 | Conv2d_10_pointwise | 187 | 16 | 93 +mobilenet_v1_075 | 321 | Conv2d_11_pointwise | 219 | 16 | 109 +mobilenet_v1_075 | 321 | Conv2d_12_pointwise | 251 | 32 | 125 +mobilenet_v1_075 | 321 | Conv2d_13_pointwise | 315 | 32 | 157 +resnet_v1_50 | None | resnet_v1_50/block1 | 35 | 8 | None +resnet_v1_50 | None | resnet_v1_50/block2 | 99 | 16 | None +resnet_v1_50 | None | resnet_v1_50/block3 | 291 | 32 | None +resnet_v1_50 | None | resnet_v1_50/block4 | 483 | 32 | None +resnet_v1_50 | 224 | resnet_v1_50/block1 | 35 | 8 | 15 +resnet_v1_50 | 224 | resnet_v1_50/block2 | 99 | 16 | 47 +resnet_v1_50 | 224 | resnet_v1_50/block3 | 291 | 32 | 143 +resnet_v1_50 | 224 | resnet_v1_50/block4 | 483 | 32 | 239 +resnet_v1_50 | 321 | resnet_v1_50/block1 | 35 | 8 | 17 +resnet_v1_50 | 321 | resnet_v1_50/block2 | 99 | 16 | 49 +resnet_v1_50 | 321 | resnet_v1_50/block3 | 291 | 32 | 145 +resnet_v1_50 | 321 | resnet_v1_50/block4 | 483 | 32 | 241 +resnet_v1_101 | None | resnet_v1_101/block1 | 35 | 8 | None +resnet_v1_101 | None | 
resnet_v1_101/block2 | 99 | 16 | None +resnet_v1_101 | None | resnet_v1_101/block3 | 835 | 32 | None +resnet_v1_101 | None | resnet_v1_101/block4 | 1027 | 32 | None +resnet_v1_101 | 224 | resnet_v1_101/block1 | 35 | 8 | 15 +resnet_v1_101 | 224 | resnet_v1_101/block2 | 99 | 16 | 47 +resnet_v1_101 | 224 | resnet_v1_101/block3 | 835 | 32 | 415 +resnet_v1_101 | 224 | resnet_v1_101/block4 | 1027 | 32 | 511 +resnet_v1_101 | 321 | resnet_v1_101/block1 | 35 | 8 | 17 +resnet_v1_101 | 321 | resnet_v1_101/block2 | 99 | 16 | 49 +resnet_v1_101 | 321 | resnet_v1_101/block3 | 835 | 32 | 417 +resnet_v1_101 | 321 | resnet_v1_101/block4 | 1027 | 32 | 513 +resnet_v1_152 | None | resnet_v1_152/block1 | 35 | 8 | None +resnet_v1_152 | None | resnet_v1_152/block2 | 163 | 16 | None +resnet_v1_152 | None | resnet_v1_152/block3 | 1315 | 32 | None +resnet_v1_152 | None | resnet_v1_152/block4 | 1507 | 32 | None +resnet_v1_152 | 224 | resnet_v1_152/block1 | 35 | 8 | 15 +resnet_v1_152 | 224 | resnet_v1_152/block2 | 163 | 16 | 79 +resnet_v1_152 | 224 | resnet_v1_152/block3 | 1315 | 32 | 655 +resnet_v1_152 | 224 | resnet_v1_152/block4 | 1507 | 32 | 751 +resnet_v1_152 | 321 | resnet_v1_152/block1 | 35 | 8 | 17 +resnet_v1_152 | 321 | resnet_v1_152/block2 | 163 | 16 | 81 +resnet_v1_152 | 321 | resnet_v1_152/block3 | 1315 | 32 | 657 +resnet_v1_152 | 321 | resnet_v1_152/block4 | 1507 | 32 | 753 +resnet_v1_200 | None | resnet_v1_200/block1 | 35 | 8 | None +resnet_v1_200 | None | resnet_v1_200/block2 | 419 | 16 | None +resnet_v1_200 | None | resnet_v1_200/block3 | 1571 | 32 | None +resnet_v1_200 | None | resnet_v1_200/block4 | 1763 | 32 | None +resnet_v1_200 | 224 | resnet_v1_200/block1 | 35 | 8 | 15 +resnet_v1_200 | 224 | resnet_v1_200/block2 | 419 | 16 | 207 +resnet_v1_200 | 224 | resnet_v1_200/block3 | 1571 | 32 | 783 +resnet_v1_200 | 224 | resnet_v1_200/block4 | 1763 | 32 | 879 +resnet_v1_200 | 321 | resnet_v1_200/block1 | 35 | 8 | 17 +resnet_v1_200 | 321 | resnet_v1_200/block2 | 419 | 16 | 209 
+resnet_v1_200 | 321 | resnet_v1_200/block3 | 1571 | 32 | 785 +resnet_v1_200 | 321 | resnet_v1_200/block4 | 1763 | 32 | 881 +resnet_v2_50 | None | resnet_v2_50/block1 | 35 | 8 | None +resnet_v2_50 | None | resnet_v2_50/block2 | 99 | 16 | None +resnet_v2_50 | None | resnet_v2_50/block3 | 291 | 32 | None +resnet_v2_50 | None | resnet_v2_50/block4 | 483 | 32 | None +resnet_v2_50 | 224 | resnet_v2_50/block1 | 35 | 8 | 15 +resnet_v2_50 | 224 | resnet_v2_50/block2 | 99 | 16 | 47 +resnet_v2_50 | 224 | resnet_v2_50/block3 | 291 | 32 | 143 +resnet_v2_50 | 224 | resnet_v2_50/block4 | 483 | 32 | 239 +resnet_v2_50 | 321 | resnet_v2_50/block1 | 35 | 8 | 17 +resnet_v2_50 | 321 | resnet_v2_50/block2 | 99 | 16 | 49 +resnet_v2_50 | 321 | resnet_v2_50/block3 | 291 | 32 | 145 +resnet_v2_50 | 321 | resnet_v2_50/block4 | 483 | 32 | 241 +resnet_v2_101 | None | resnet_v2_101/block1 | 35 | 8 | None +resnet_v2_101 | None | resnet_v2_101/block2 | 99 | 16 | None +resnet_v2_101 | None | resnet_v2_101/block3 | 835 | 32 | None +resnet_v2_101 | None | resnet_v2_101/block4 | 1027 | 32 | None +resnet_v2_101 | 224 | resnet_v2_101/block1 | 35 | 8 | 15 +resnet_v2_101 | 224 | resnet_v2_101/block2 | 99 | 16 | 47 +resnet_v2_101 | 224 | resnet_v2_101/block3 | 835 | 32 | 415 +resnet_v2_101 | 224 | resnet_v2_101/block4 | 1027 | 32 | 511 +resnet_v2_101 | 321 | resnet_v2_101/block1 | 35 | 8 | 17 +resnet_v2_101 | 321 | resnet_v2_101/block2 | 99 | 16 | 49 +resnet_v2_101 | 321 | resnet_v2_101/block3 | 835 | 32 | 417 +resnet_v2_101 | 321 | resnet_v2_101/block4 | 1027 | 32 | 513 +resnet_v2_152 | None | resnet_v2_152/block1 | 35 | 8 | None +resnet_v2_152 | None | resnet_v2_152/block2 | 163 | 16 | None +resnet_v2_152 | None | resnet_v2_152/block3 | 1315 | 32 | None +resnet_v2_152 | None | resnet_v2_152/block4 | 1507 | 32 | None +resnet_v2_152 | 224 | resnet_v2_152/block1 | 35 | 8 | 15 +resnet_v2_152 | 224 | resnet_v2_152/block2 | 163 | 16 | 79 +resnet_v2_152 | 224 | resnet_v2_152/block3 | 1315 | 32 | 655 
+resnet_v2_152 | 224 | resnet_v2_152/block4 | 1507 | 32 | 751
+resnet_v2_152 | 321 | resnet_v2_152/block1 | 35 | 8 | 17
+resnet_v2_152 | 321 | resnet_v2_152/block2 | 163 | 16 | 81
+resnet_v2_152 | 321 | resnet_v2_152/block3 | 1315 | 32 | 657
+resnet_v2_152 | 321 | resnet_v2_152/block4 | 1507 | 32 | 753
+resnet_v2_200 | None | resnet_v2_200/block1 | 35 | 8 | None
+resnet_v2_200 | None | resnet_v2_200/block2 | 419 | 16 | None
+resnet_v2_200 | None | resnet_v2_200/block3 | 1571 | 32 | None
+resnet_v2_200 | None | resnet_v2_200/block4 | 1763 | 32 | None
+resnet_v2_200 | 224 | resnet_v2_200/block1 | 35 | 8 | 15
+resnet_v2_200 | 224 | resnet_v2_200/block2 | 419 | 16 | 207
+resnet_v2_200 | 224 | resnet_v2_200/block3 | 1571 | 32 | 783
+resnet_v2_200 | 224 | resnet_v2_200/block4 | 1763 | 32 | 879
+resnet_v2_200 | 321 | resnet_v2_200/block1 | 35 | 8 | 17
+resnet_v2_200 | 321 | resnet_v2_200/block2 | 419 | 16 | 209
+resnet_v2_200 | 321 | resnet_v2_200/block3 | 1571 | 32 | 785
+resnet_v2_200 | 321 | resnet_v2_200/block4 | 1763 | 32 | 881
+
+## FAQ
+
+### What does a resolution of 'None' mean?
+
+In this case, the input resolution is undefined. For most models, the receptive
+field parameters can be computed even without knowing the input resolution.
+
+### For some networks, effective_padding shows as 'None' (e.g., for Inception_v2 or Mobilenet_v1 when input size is not specified). Why is that?
+
+This means that the padding for these networks depends on the input size. So,
+unless we know the exact input image dimensions to be used, it is not
+possible to determine the padding applied at the different layers. Look at the
+other entries where the input size is fixed; for those cases, effective_padding
+is not None.
+
+This happens due to TensorFlow's implementation of the 'SAME' padding mode,
+which may depend on the input feature map size to a given layer.
For background
+on this, see [these notes from the TF
+documentation](https://www.tensorflow.org/versions/master/api_guides/python/nn#Notes_on_SAME_Convolution_Padding).
+
+Also, note that in this case the program is not able to check whether the
+network is aligned (i.e., it could be that the different paths from input to
+output have receptive fields that are not consistently centered at the same
+position in the input image).
+
+So you should be aware that such networks might not be aligned -- the program
+has no way of checking this when the padding cannot be determined.
+
+### The receptive field parameters for network X seem different from what I expected... maybe your calculation is incorrect?
+
+First, note that the results presented here are based on the TensorFlow
+implementations from the [TF-Slim model
+library](https://github.com/tensorflow/models/tree/master/research/slim).
+
+So, it is possible that, due to implementation details, the RF parameters differ
+from what you expected.
+
+One common source of confusion is the TF-Slim ResNet implementation, which
+applies stride in the last residual unit of each block, instead of at the input
+activations in the first residual unit of each block (which is what is described
+in the ResNet paper) -- see [this
+comment](https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_utils.py#L30).
+This makes the stride with respect to each convolution block potentially
+different. In this case, though, note that a
+[flag](https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_v1.py#L150)
+may be used to recover the original striding convention.
+
+Second, it could be that we have a bug somewhere. While we include [many
+tests](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py)
+in our library, it is always possible that we missed something.
If you suspect
+this is happening, please file a GitHub issue
+[here](https://github.com/tensorflow/tensorflow/issues).
diff --git a/tensorflow/contrib/receptive_field/python/util/examples/csv_to_markdown_table.py b/tensorflow/contrib/receptive_field/python/util/examples/csv_to_markdown_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..4495d74bbf66fa461a05f38b430dd404d7da4b08
--- /dev/null
+++ b/tensorflow/contrib/receptive_field/python/util/examples/csv_to_markdown_table.py
@@ -0,0 +1,82 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Simple script to convert CSV output from rf_benchmark to Markdown format.
+
+The input CSV should have the following fields:
+- CNN
+- input resolution
+- end_point
+- RF size hor
+- RF size ver
+- effective stride hor
+- effective stride ver
+- effective padding hor
+- effective padding ver
+
+The parameters in the horizontal and vertical directions are usually the same.
+This script asserts that they match and writes only one of them to the
+Markdown file.
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import csv +import sys + +from tensorflow.python.platform import app + +cmd_args = None + + +def main(unused_argv): + with open(cmd_args.markdown_path, 'w') as f: + # Write table header and column alignment row. + f.write('CNN | resolution | end-point | RF | effective stride | ' + 'effective padding|\n') + f.write( + ':--------------------: | :----------: | :---------------: | :-----: |' + ' :----: | :----:|\n') + with open(cmd_args.csv_path) as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + # Make sure horizontal and vertical parameters are the same. + assert row['RF size hor'] == row['RF size ver'] + assert row['effective stride hor'] == row['effective stride ver'] + assert row['effective padding hor'] == row['effective padding ver'] + + f.write('%s|%s|%s|%s|%s|%s\n' % + (row['CNN'], row['input resolution'], row['end_point'], + row['RF size hor'], row['effective stride hor'], + row['effective padding hor'])) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.register('type', 'bool', lambda v: v.lower() == 'true') + parser.add_argument( + '--csv_path', + type=str, + default='/tmp/rf.csv', + help='Path where CSV output of rf_benchmark was saved.') + parser.add_argument( + '--markdown_path', + type=str, + default='/tmp/rf.md', + help='Path where Markdown output will be saved.') + cmd_args, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py index bc383a803496380aaba4d0248d2b7f93253b2b50..0e3c46f17d2e2a277418d39e31927db73a509670 100644 --- a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py +++ b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py @@ -27,7 +27,7 @@
from tensorflow.python.platform import tf_logging as logging _UNCHANGED_RF_LAYER_OPS = [ "Add", "BiasAdd", "Cast", "Ceil", "ConcatV2", "Const", "Floor", "FusedBatchNorm", "Identity", "Log", "Mul", "Pow", "RealDiv", "Relu", - "Relu6", "Round", "Rsqrt", "Softplus", "Sub", "VariableV2" + "Relu6", "Round", "Rsqrt", "Softplus", "Sub", "VariableV2", "LRN" ] # Different ways in which padding modes may be spelled. diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py b/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py index cf55da27236d17c709cbde689831ad68da9a8a7b..a42bbca61135a5c1666f1964c25af9c105b472bb 100644 --- a/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py +++ b/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py @@ -385,7 +385,7 @@ class ReceptiveFieldTest(test.TestCase): effective_stride_y, effective_padding_x, effective_padding_y) = ( receptive_field.compute_receptive_field_from_graph_def( graph_def, input_node, output_node, - ['Dropout/dropout/random_uniform'])) + ['Dropout/dropout_1/random_uniform'])) self.assertEqual(receptive_field_x, 3) self.assertEqual(receptive_field_y, 3) self.assertEqual(effective_stride_x, 4) diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 43c0f7595590802aa80e1012967d377a6ab83d29..4eb5c920b3517a8968ff730003e786ae2a9c9e26 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -193,6 +193,10 @@ tf_py_test( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], + tags = [ + "manual", + "notap", + ], ) cuda_py_tests( diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index d41fc0b3ac1cee4eacc88cb0f41df1f9ee59e7c3..b8840a8f2420f1bc6c75f0a02e5465c595378dec 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ 
b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import functools +import os import numpy as np @@ -30,6 +31,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -39,6 +41,7 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import util as checkpointable_utils # pylint: enable=protected-access Linear = core_rnn_cell._Linear # pylint: disable=invalid-name @@ -189,6 +192,7 @@ class RNNCellTest(test.TestCase): self.assertEqual(cell.dtype, None) self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name) self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name) + cell.get_config() # Should not throw an error g, out_m = cell(x, m) # Layer infers the input type. 
self.assertEqual(cell.dtype, dtype.name) @@ -439,6 +443,26 @@ class RNNCellTest(test.TestCase): self.assertTrue( float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6) + @test_util.run_in_graph_and_eager_modes() + def testWrapperCheckpointing(self): + for wrapper_type in [ + rnn_cell_impl.DropoutWrapper, + rnn_cell_impl.ResidualWrapper, + lambda cell: rnn_cell_impl.MultiRNNCell([cell])]: + with self.test_session(): + cell = rnn_cell_impl.BasicRNNCell(1) + wrapper = wrapper_type(cell) + wrapper(array_ops.ones([1, 1]), + state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32)) + self.evaluate([v.initializer for v in cell.variables]) + checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper) + prefix = os.path.join(self.get_temp_dir(), "ckpt") + self.evaluate(cell._bias.assign([40.])) + save_path = checkpoint.save(prefix) + self.evaluate(cell._bias.assign([0.])) + checkpoint.restore(save_path).assert_consumed().run_restore_ops() + self.assertAllEqual([40.], self.evaluate(cell._bias)) + def testOutputProjectionWrapper(self): with self.test_session() as sess: with variable_scope.variable_scope( @@ -483,7 +507,13 @@ class RNNCellTest(test.TestCase): base_cell = rnn_cell_impl.GRUCell(3) g, m_new = base_cell(x, m) variable_scope.get_variable_scope().reuse_variables() - g_res, m_new_res = rnn_cell_impl.ResidualWrapper(base_cell)(x, m) + wrapper_object = rnn_cell_impl.ResidualWrapper(base_cell) + (name, dep), = wrapper_object._checkpoint_dependencies + wrapper_object.get_config() # Should not throw an error + self.assertIs(dep, base_cell) + self.assertEqual("cell", name) + + g_res, m_new_res = wrapper_object(x, m) sess.run([variables_lib.global_variables_initializer()]) res = sess.run([g, g_res, m_new, m_new_res], { x: np.array([[1., 1., 1.]]), @@ -526,7 +556,13 @@ class RNNCellTest(test.TestCase): "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) m = array_ops.zeros([1, 3]) - cell = 
rnn_cell_impl.DeviceWrapper(rnn_cell_impl.GRUCell(3), "/cpu:14159") + wrapped = rnn_cell_impl.GRUCell(3) + cell = rnn_cell_impl.DeviceWrapper(wrapped, "/cpu:14159") + (name, dep), = cell._checkpoint_dependencies + cell.get_config() # Should not throw an error + self.assertIs(dep, wrapped) + self.assertEqual("cell", name) + outputs, _ = cell(x, m) self.assertTrue("cpu:14159" in outputs.device.lower()) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index ba4933ddf793c58d00ae28a54eb21410f41e2e16..be99a5d67a3e49b1d522406601d050392f75e963 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -38,6 +38,7 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import state_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib @@ -142,6 +143,47 @@ class TestStateSaver(object): self.saved_state[name] = state return array_ops.identity(state) + @property + def batch_size(self): + return self._batch_size + + @property + def state_size(self): + return self._state_size + + +class TestStateSaverWithCounters(TestStateSaver): + """Class wrapper around TestStateSaver. + + A dummy class used for testing static_state_saving_rnn. It helps test whether + the save_state and state functions are called the same number of times when we + evaluate the output of the RNN cell and the state, or either of them separately. It + inherits from TestStateSaver and adds counters for calls to those functions.
+ """ + + def __init__(self, batch_size, state_size): + super(TestStateSaverWithCounters, self).__init__(batch_size, state_size) + self._num_state_calls = variables_lib.Variable(0) + self._num_save_state_calls = variables_lib.Variable(0) + + def state(self, name): + with ops_lib.control_dependencies( + [state_ops.assign_add(self._num_state_calls, 1)]): + return super(TestStateSaverWithCounters, self).state(name) + + def save_state(self, name, state): + with ops_lib.control_dependencies([state_ops.assign_add( + self._num_save_state_calls, 1)]): + return super(TestStateSaverWithCounters, self).save_state(name, state) + + @property + def num_state_calls(self): + return self._num_state_calls + + @property + def num_save_state_calls(self): + return self._num_save_state_calls + class RNNTest(test.TestCase): @@ -186,6 +228,9 @@ class RNNTest(test.TestCase): cell = Plus1RNNCell() full_dropout_cell = rnn_cell.DropoutWrapper( cell, input_keep_prob=1e-12, seed=0) + (name, dep), = full_dropout_cell._checkpoint_dependencies + self.assertIs(dep, cell) + self.assertEqual("cell", name) batch_size = 2 input_size = 5 max_length = 8 @@ -1792,13 +1837,40 @@ class StateSaverRNNTest(test.TestCase): self._seed = 23489 np.random.seed(self._seed) - def _testScope(self, factory, prefix="prefix", use_outer_scope=True): + def _factory(self, scope, state_saver): + num_units = state_saver.state_size // 2 + batch_size = state_saver.batch_size + input_size = 5 + max_length = 8 + initializer = init_ops.random_uniform_initializer( + -0.01, 0.01, seed=self._seed) + cell = rnn_cell.LSTMCell( + num_units, + use_peepholes=False, + initializer=initializer, + state_is_tuple=False) + inputs = max_length * [ + array_ops.zeros(dtype=dtypes.float32, shape=(batch_size, input_size)) + ] + out, state = rnn.static_state_saving_rnn( + cell, + inputs, + state_saver=state_saver, + state_name="save_lstm", + scope=scope) + return out, state, state_saver + + def _testScope(self, prefix="prefix", use_outer_scope=True): 
+ num_units = 3 + batch_size = 2 + state_saver = TestStateSaver(batch_size, 2 * num_units) + with self.test_session(use_gpu=True, graph=ops_lib.Graph()): if use_outer_scope: with variable_scope.variable_scope(prefix) as scope: - factory(scope) + self._factory(scope=scope, state_saver=state_saver) else: - factory(prefix) + self._factory(scope=prefix, state_saver=state_saver) variables_lib.global_variables_initializer() # check that all the variables names starts @@ -1813,34 +1885,46 @@ class StateSaverRNNTest(test.TestCase): self.assertEqual(len(scope_vars), len(all_vars)) def testStateSaverRNNScope(self): - num_units = 3 - input_size = 5 - batch_size = 2 - max_length = 8 + self._testScope(use_outer_scope=True) + self._testScope(use_outer_scope=False) + self._testScope(prefix=None, use_outer_scope=False) - def factory(scope): - initializer = init_ops.random_uniform_initializer( - -0.01, 0.01, seed=self._seed) - state_saver = TestStateSaver(batch_size, 2 * num_units) - cell = rnn_cell.LSTMCell( - num_units, - use_peepholes=False, - initializer=initializer, - state_is_tuple=False) - inputs = max_length * [ - array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) - ] - return rnn.static_state_saving_rnn( - cell, - inputs, - state_saver=state_saver, - state_name="save_lstm", - scope=scope) + def testStateSaverCallsSaveState(self): + """Test that number of calls to state and save_state is equal. - self._testScope(factory, use_outer_scope=True) - self._testScope(factory, use_outer_scope=False) - self._testScope(factory, prefix=None, use_outer_scope=False) + Test if the order of actual evaluating or skipping evaluation of out, + state tensors, which are the output tensors from static_state_saving_rnn, + have influence on number of calls to save_state and state methods of + state_saver object (the number of calls should be same.) 
+ """ + num_units = 3 + batch_size = 2 + state_saver = TestStateSaverWithCounters(batch_size, 2 * num_units) + out, state, state_saver = self._factory(scope=None, state_saver=state_saver) + + with self.test_session() as sess: + sess.run(variables_lib.global_variables_initializer()) + sess.run(variables_lib.local_variables_initializer()) + + _, _, num_state_calls, num_save_state_calls = sess.run([ + out, + state, + state_saver.num_state_calls, + state_saver.num_save_state_calls]) + self.assertEqual(num_state_calls, num_save_state_calls) + + _, num_state_calls, num_save_state_calls = sess.run([ + out, + state_saver.num_state_calls, + state_saver.num_save_state_calls]) + self.assertEqual(num_state_calls, num_save_state_calls) + + _, num_state_calls, num_save_state_calls = sess.run([ + state, + state_saver.num_state_calls, + state_saver.num_save_state_calls]) + self.assertEqual(num_state_calls, num_save_state_calls) class GRUTest(test.TestCase): diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD index fdecceff526a860a274354e53e824b98d11418a6..6bd58c4d322c04d4d14d04678e24a05c0f876208 100644 --- a/tensorflow/contrib/signal/BUILD +++ b/tensorflow/contrib/signal/BUILD @@ -1,4 +1,4 @@ -package(default_visibility = ["//tensorflow:__subpackages__"]) +package(default_visibility = ["//tensorflow:internal"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/signal/python/ops/window_ops.py b/tensorflow/contrib/signal/python/ops/window_ops.py index 50094010dc75cf8b3c62da5e3a7ed5e995e6df41..59e67e8ba414df1f9c777d1f5a3f3dba975648a2 100644 --- a/tensorflow/contrib/signal/python/ops/window_ops.py +++ b/tensorflow/contrib/signal/python/ops/window_ops.py @@ -47,7 +47,7 @@ def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None): Raises: ValueError: If `dtype` is not a floating point type. 
- [hann]: https://en.wikipedia.org/wiki/Window_function#Hann_window + [hann]: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows """ return _raised_cosine_window(name, 'hann_window', window_length, periodic, dtype, 0.5, 0.5) @@ -72,7 +72,7 @@ def hamming_window(window_length, periodic=True, dtype=dtypes.float32, Raises: ValueError: If `dtype` is not a floating point type. - [hamming]: https://en.wikipedia.org/wiki/Window_function#Hamming_window + [hamming]: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows """ return _raised_cosine_window(name, 'hamming_window', window_length, periodic, dtype, 0.54, 0.46) diff --git a/tensorflow/contrib/slim/README.md b/tensorflow/contrib/slim/README.md index a688f0f2803a2f0bb397f4b8d3a74059fc7e37af..f2bb458848fab5603128903868b52f29785efc92 100644 --- a/tensorflow/contrib/slim/README.md +++ b/tensorflow/contrib/slim/README.md @@ -912,5 +912,5 @@ Sergio Guadarrama and Nathan Silberman ## Citation "TensorFlow-Slim: a lightweight library for defining, training and evaluating complex models in TensorFlow" -S. Guadarrama, N. Silberman -https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim 2016 +S. Guadarrama, N. Silberman, 2016. 
+https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py index 94fc12ca814721acf62f16b72ffa50473043cc8b..3d0308aaf3da3b5b16fd22a2905db36917e8c97b 100644 --- a/tensorflow/contrib/slim/python/slim/evaluation_test.py +++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py @@ -26,7 +26,6 @@ import time import numpy as np from tensorflow.contrib.framework.python.ops import variables as variables_lib -from tensorflow.contrib.metrics.python.ops import metric_ops from tensorflow.contrib.slim.python.slim import evaluation from tensorflow.contrib.training.python.training import evaluation as evaluation_lib from tensorflow.core.protobuf import saver_pb2 @@ -37,6 +36,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics from tensorflow.python.ops import variables from tensorflow.python.platform import flags from tensorflow.python.platform import gfile @@ -89,8 +89,8 @@ class EvaluationTest(test.TestCase): self._predictions, self._scale = TestModel(self._inputs) def testFinalOpsOnEvaluationLoop(self): - value_op, update_op = metric_ops.streaming_accuracy(self._predictions, - self._labels) + value_op, update_op = metrics.accuracy( + labels=self._labels, predictions=self._predictions) init_op = control_flow_ops.group(variables.global_variables_initializer(), variables.local_variables_initializer()) # Create checkpoint and log directories: @@ -136,9 +136,10 @@ class EvaluationTest(test.TestCase): self.assertTrue(obj.hook_was_run) def _create_names_to_metrics(self, predictions, labels): - accuracy0, update_op0 = metric_ops.streaming_accuracy(predictions, labels) - accuracy1, update_op1 = metric_ops.streaming_accuracy(predictions + 1, - labels) + 
accuracy0, update_op0 = metrics.accuracy( + labels=labels, predictions=predictions) + accuracy1, update_op1 = metrics.accuracy( + labels=labels, predictions=predictions + 1) names_to_values = {'Accuracy': accuracy0, 'Another_accuracy': accuracy1} names_to_updates = {'Accuracy': update_op0, 'Another_accuracy': update_op1} @@ -198,8 +199,8 @@ class EvaluationTest(test.TestCase): predictions_limited = input.limit_epochs(self._predictions, num_epochs=1) labels_limited = input.limit_epochs(self._labels, num_epochs=1) - value_op, update_op = metric_ops.streaming_accuracy( - predictions_limited, labels_limited) + value_op, update_op = metrics.accuracy( + labels=labels_limited, predictions=predictions_limited) init_op = control_flow_ops.group(variables.global_variables_initializer(), variables.local_variables_initializer()) @@ -260,8 +261,8 @@ class SingleEvaluationTest(test.TestCase): self._prepareCheckpoint(checkpoint_path) # Next, determine the metric to evaluate: - value_op, update_op = metric_ops.streaming_accuracy(self._predictions, - self._labels) + value_op, update_op = metrics.accuracy( + labels=self._labels, predictions=self._predictions) # Run the evaluation and verify the results: accuracy_value = evaluation.evaluate_once( @@ -276,8 +277,8 @@ class SingleEvaluationTest(test.TestCase): self._prepareCheckpoint(checkpoint_path) # Next, determine the metric to evaluate: - value_op, update_op = metric_ops.streaming_accuracy(self._predictions, - self._labels) + value_op, update_op = metrics.accuracy( + labels=self._labels, predictions=self._predictions) dumping_root = os.path.join(self.get_temp_dir(), 'tfdbg_dump_dir') dumping_hook = hooks.DumpingDebugHook(dumping_root, log_usage=False) diff --git a/tensorflow/contrib/solvers/python/ops/linear_equations.py b/tensorflow/contrib/solvers/python/ops/linear_equations.py index 9305c6a11c4ec898c82553773e8e7277a54ab82e..85918bf8506623cf5e0c9106ae9ed80e233f5a7d 100644 --- 
a/tensorflow/contrib/solvers/python/ops/linear_equations.py +++ b/tensorflow/contrib/solvers/python/ops/linear_equations.py @@ -28,7 +28,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import linalg_ops def conjugate_gradient(operator, diff --git a/tensorflow/contrib/sparsemax/BUILD b/tensorflow/contrib/sparsemax/BUILD index b729fff261192be22c6a56fa9ca0a641f302c570..d7ba754f701d4b433e35ad8396eae7ee6132b97f 100644 --- a/tensorflow/contrib/sparsemax/BUILD +++ b/tensorflow/contrib/sparsemax/BUILD @@ -38,7 +38,7 @@ py_library( cuda_py_tests( name = "sparsemax_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/sparsemax_test.py"], additional_deps = [ ":sparsemax_py", diff --git a/tensorflow/contrib/stat_summarizer/BUILD b/tensorflow/contrib/stat_summarizer/BUILD index 30be14c10cd8576ded75b8489cc89d439a9cc282..0b8fc0cdc66ae41807cce92776ada263675b1f94 100644 --- a/tensorflow/contrib/stat_summarizer/BUILD +++ b/tensorflow/contrib/stat_summarizer/BUILD @@ -31,5 +31,8 @@ tf_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:variables", ], - tags = ["no_windows"], + tags = [ + "no_windows", + "notap", # TODO(b/80546574): test is flaky + ], ) diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py index 99ced53e1167ec5486d0b75cff81ffbf857c2be7..d22b80ac88a9ced541a952fcbb58c50366464075 100644 --- a/tensorflow/contrib/summary/summary.py +++ b/tensorflow/contrib/summary/summary.py @@ -21,6 +21,7 @@ from @{tf.summary.merge_all} to @{tf.summary.FileWriter}. 
To use with eager execution enabled, write your code as follows: +```python global_step = tf.train.get_or_create_global_step() summary_writer = tf.contrib.summary.create_file_writer( train_dir, flush_millis=10000) @@ -30,9 +31,11 @@ with summary_writer.as_default(), tf.contrib.summary.always_record_summaries(): tf.contrib.summary.scalar("loss", my_loss) # In this case every call to tf.contrib.summary.scalar will generate a record # ... +``` To use it with graph execution, write your code as follows: +```python global_step = tf.train.get_or_create_global_step() summary_writer = tf.contrib.summary.create_file_writer( train_dir, flush_millis=10000) @@ -53,7 +56,7 @@ with tf.Session(...) as sess: while not_done_training: sess.run([train_op, tf.contrib.summary.all_summary_ops()]) # ... - +``` """ from __future__ import absolute_import diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py index e893e1d1c836cc7feef15757dde79d0db362cbaf..d8236a0a6fa6d0d0e383e454eb0146bb10b6f49d 100644 --- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py +++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py @@ -21,10 +21,10 @@ import numpy as np from tensorflow.contrib import losses from tensorflow.contrib.learn.python.learn.estimators import prediction_key -from tensorflow.contrib.metrics.python.ops import metric_ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics from tensorflow.python.ops import nn INFERENCE_PROB_NAME = prediction_key.PredictionKey.PROBABILITIES @@ -38,12 +38,13 @@ def _top_k_generator(k): targets = math_ops.to_int32(targets) if targets.get_shape().ndims > 1: targets = array_ops.squeeze(targets, axis=[1]) - return metric_ops.streaming_mean(nn.in_top_k(probabilities, targets, k)) + return metrics.mean(nn.in_top_k(probabilities, targets, k)) return _top_k def _accuracy(predictions, targets, 
weights=None): - return metric_ops.streaming_accuracy(predictions, targets, weights=weights) + return metrics.accuracy( + labels=targets, predictions=predictions, weights=weights) def _r2(probabilities, targets, weights=None): @@ -53,7 +54,7 @@ def _r2(probabilities, targets, weights=None): squares_residuals = math_ops.reduce_sum( math_ops.square(targets - probabilities), 0) score = 1 - math_ops.reduce_sum(squares_residuals / squares_total) - return metric_ops.streaming_mean(score, weights=weights) + return metrics.mean(score, weights=weights) def _squeeze_and_onehot(targets, depth): @@ -62,7 +63,7 @@ def _squeeze_and_onehot(targets, depth): def _sigmoid_entropy(probabilities, targets, weights=None): - return metric_ops.streaming_mean( + return metrics.mean( losses.sigmoid_cross_entropy(probabilities, _squeeze_and_onehot( targets, @@ -71,7 +72,7 @@ def _sigmoid_entropy(probabilities, targets, weights=None): def _softmax_entropy(probabilities, targets, weights=None): - return metric_ops.streaming_mean( + return metrics.mean( losses.sparse_softmax_cross_entropy(probabilities, math_ops.to_int32(targets)), weights=weights) @@ -82,7 +83,7 @@ def _predictions(predictions, unused_targets, **unused_kwargs): def _class_log_loss(probabilities, targets, weights=None): - return metric_ops.streaming_mean( + return metrics.mean( losses.log_loss(probabilities, _squeeze_and_onehot(targets, array_ops.shape(probabilities)[1])), @@ -90,34 +91,36 @@ def _class_log_loss(probabilities, targets, weights=None): def _precision(predictions, targets, weights=None): - return metric_ops.streaming_precision(predictions, targets, weights=weights) + return metrics.precision( + labels=targets, predictions=predictions, weights=weights) def _precision_at_thresholds(predictions, targets, weights=None): - return metric_ops.streaming_precision_at_thresholds( - array_ops.slice(predictions, [0, 1], [-1, 1]), - targets, - np.arange( - 0, 1, 0.01, dtype=np.float32), + return 
metrics.precision_at_thresholds( + labels=targets, + predictions=array_ops.slice(predictions, [0, 1], [-1, 1]), + thresholds=np.arange(0, 1, 0.01, dtype=np.float32), weights=weights) def _recall(predictions, targets, weights=None): - return metric_ops.streaming_recall(predictions, targets, weights=weights) + return metrics.recall( + labels=targets, predictions=predictions, weights=weights) def _recall_at_thresholds(predictions, targets, weights=None): - return metric_ops.streaming_recall_at_thresholds( - array_ops.slice(predictions, [0, 1], [-1, 1]), - targets, - np.arange( - 0, 1, 0.01, dtype=np.float32), + return metrics.recall_at_thresholds( + labels=targets, + predictions=array_ops.slice(predictions, [0, 1], [-1, 1]), + thresholds=np.arange(0, 1, 0.01, dtype=np.float32), weights=weights) def _auc(probs, targets, weights=None): - return metric_ops.streaming_auc(array_ops.slice(probs, [0, 1], [-1, 1]), - targets, weights=weights) + return metrics.auc( + labels=targets, + predictions=array_ops.slice(probs, [0, 1], [-1, 1]), + weights=weights) _EVAL_METRICS = { diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index 7a35a70bbe3112e0649cefd8116cc50565978da5..6f62cd11a9733949c350e35b6b0c436dd097cc33 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -295,7 +295,7 @@ def get_epoch_variable(): # A simple container to hold the training variables for a single tree. -class TreeTrainingVariables(object): +class TreeVariables(object): """Stores tf.Variables for training a single random tree. Uses tf.get_variable to get tree-specific names so that this can be used @@ -303,7 +303,7 @@ class TreeTrainingVariables(object): then relies on restoring that model to evaluate). 
""" - def __init__(self, params, tree_num, training): + def __init__(self, params, tree_num, training, tree_config='', tree_stat=''): if (not hasattr(params, 'params_proto') or not isinstance(params.params_proto, _params_proto.TensorForestParams)): @@ -315,27 +315,28 @@ class TreeTrainingVariables(object): # TODO(gilberth): Manually shard this to be able to fit it on # multiple machines. self.stats = stats_ops.fertile_stats_variable( - params, '', self.get_tree_name('stats', tree_num)) + params, tree_stat, self.get_tree_name('stats', tree_num)) self.tree = model_ops.tree_variable( - params, '', self.stats, self.get_tree_name('tree', tree_num)) + params, tree_config, self.stats, self.get_tree_name('tree', tree_num)) def get_tree_name(self, name, num): return '{0}-{1}'.format(name, num) -class ForestTrainingVariables(object): +class ForestVariables(object): """A container for a forests training data, consisting of multiple trees. - Instantiates a TreeTrainingVariables object for each tree. We override the + Instantiates a TreeVariables object for each tree. We override the __getitem__ and __setitem__ function so that usage looks like this: - forest_variables = ForestTrainingVariables(params) + forest_variables = ForestVariables(params) ... forest_variables.tree ... """ def __init__(self, params, device_assigner, training=True, - tree_variables_class=TreeTrainingVariables): + tree_variables_class=TreeVariables, + tree_configs=None, tree_stats=None): self.variables = [] # Set up some scalar variables to run through the device assigner, then # we can use those to colocate everything related to a tree. 
@@ -347,7 +348,13 @@ class ForestTrainingVariables(object): for i in range(params.num_trees): with ops.device(self.device_dummies[i].device): - self.variables.append(tree_variables_class(params, i, training)) + kwargs = {} + if tree_configs is not None: + kwargs.update(dict(tree_config=tree_configs[i])) + if tree_stats is not None: + kwargs.update(dict(tree_stat=tree_stats[i])) + self.variables.append(tree_variables_class( + params, i, training, **kwargs)) def __setitem__(self, t, val): self.variables[t] = val @@ -361,9 +368,11 @@ class RandomForestGraphs(object): def __init__(self, params, + tree_configs=None, + tree_stats=None, device_assigner=None, variables=None, - tree_variables_class=TreeTrainingVariables, + tree_variables_class=TreeVariables, tree_graphs=None, training=True): self.params = params @@ -371,9 +380,10 @@ class RandomForestGraphs(object): device_assigner or framework_variables.VariableDeviceChooser()) logging.info('Constructing forest with params = ') logging.info(self.params.__dict__) - self.variables = variables or ForestTrainingVariables( + self.variables = variables or ForestVariables( self.params, device_assigner=self.device_assigner, training=training, - tree_variables_class=tree_variables_class) + tree_variables_class=tree_variables_class, + tree_configs=tree_configs, tree_stats=tree_stats) tree_graph_class = tree_graphs or RandomTreeGraphs self.trees = [ tree_graph_class(self.variables[i], self.params, i) diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py index bbe627b15773fafe83a0700da696f429876c0968..1c9c81827e0f251c8ae7bc47242334fb202835ac 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py @@ -18,10 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from google.protobuf.json_format import 
ParseDict +from tensorflow.contrib.decision_trees.proto import generic_tree_model_pb2 as _tree_proto from tensorflow.contrib.tensor_forest.python import tensor_forest from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util +from tensorflow.python.ops import resources +from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -110,6 +114,47 @@ class TensorForestTest(test_util.TensorFlowTestCase): self.assertTrue(isinstance(paths, ops.Tensor)) self.assertTrue(isinstance(var, ops.Tensor)) + def testInfrenceFromRestoredModel(self): + input_data = [[-1., 0.], [-1., 2.], # node 1 + [1., 0.], [1., -2.]] # node 2 + expected_prediction = [[0.0, 1.0], [0.0, 1.0], + [0.0, 1.0], [0.0, 1.0]] + hparams = tensor_forest.ForestHParams( + num_classes=2, + num_features=2, + num_trees=1, + max_nodes=1000, + split_after_samples=25).fill() + tree_weight = {'decisionTree': + {'nodes': + [{'binaryNode': + {'rightChildId': 2, + 'leftChildId': 1, + 'inequalityLeftChildTest': + {'featureId': {'id': '0'}, + 'threshold': {'floatValue': 0}}}}, + {'leaf': {'vector': + {'value': [{'floatValue': 0.0}, + {'floatValue': 1.0}]}}, + 'nodeId': 1}, + {'leaf': {'vector': + {'value': [{'floatValue': 0.0}, + {'floatValue': 1.0}]}}, + 'nodeId': 2}]}} + restored_tree_param = ParseDict(tree_weight, + _tree_proto.Model()).SerializeToString() + graph_builder = tensor_forest.RandomForestGraphs(hparams, + [restored_tree_param]) + probs, paths, var = graph_builder.inference_graph(input_data) + self.assertTrue(isinstance(probs, ops.Tensor)) + self.assertTrue(isinstance(paths, ops.Tensor)) + self.assertTrue(isinstance(var, ops.Tensor)) + with self.test_session(): + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + self.assertEquals(probs.eval().shape, (4, 2)) + self.assertEquals(probs.eval().tolist(), 
expected_prediction) + def testTrainingConstructionClassificationSparse(self): input_data = sparse_tensor.SparseTensor( indices=[[0, 0], [0, 3], [1, 0], [1, 7], [2, 1], [3, 9]], diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc index 630c0607ae21d0276a9dd0507346d5dc4ed9f4a9..cfdc884277a025aa11995d329389f3748b17490c 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" +#include + #include "tensorflow/contrib/tensorboard/db/summary_converter.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -66,14 +68,9 @@ const char* kImagePluginName = "images"; const char* kAudioPluginName = "audio"; const char* kHistogramPluginName = "histograms"; -const int kScalarSlots = 10000; -const int kImageSlots = 10; -const int kAudioSlots = 10; -const int kHistogramSlots = 1; -const int kTensorSlots = 10; - const int64 kReserveMinBytes = 32; const double kReserveMultiplier = 1.5; +const int64 kPreallocateRows = 1000; // Flush is a misnomer because what we're actually doing is having lots // of commits inside any SqliteTransaction that writes potentially @@ -139,22 +136,6 @@ void PatchPluginName(SummaryMetadata* metadata, const char* name) { } } -int GetSlots(const Tensor& t, const SummaryMetadata& metadata) { - if (metadata.plugin_data().plugin_name() == kScalarPluginName) { - return kScalarSlots; - } else if (metadata.plugin_data().plugin_name() == kImagePluginName) { - return kImageSlots; - } else if (metadata.plugin_data().plugin_name() == kAudioPluginName) { - return kAudioSlots; - } else if (metadata.plugin_data().plugin_name() == kHistogramPluginName) { - return kHistogramSlots; - } else 
if (t.dims() == 0 && t.dtype() != DT_STRING) { - return kScalarSlots; - } else { - return kTensorSlots; - } -} - Status SetDescription(Sqlite* db, int64 id, const StringPiece& markdown) { const char* sql = R"sql( INSERT OR REPLACE INTO Descriptions (id, description) VALUES (?, ?) @@ -481,24 +462,6 @@ class RunMetadata { return insert.StepAndReset(); } - Status GetIsWatching(Sqlite* db, bool* is_watching) - SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { - mutex_lock lock(mu_); - if (experiment_id_ == kAbsent) { - *is_watching = true; - return Status::OK(); - } - const char* sql = R"sql( - SELECT is_watching FROM Experiments WHERE experiment_id = ? - )sql"; - SqliteStatement stmt; - TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt)); - stmt.BindInt(1, experiment_id_); - TF_RETURN_IF_ERROR(stmt.StepOnce()); - *is_watching = stmt.ColumnInt(0) != 0; - return Status::OK(); - } - private: Status InitializeUser(Sqlite* db, uint64 now) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (user_id_ != kAbsent || user_name_.empty()) return Status::OK(); @@ -659,43 +622,15 @@ class RunMetadata { /// \brief Tensor writer for a single series, e.g. Tag. /// -/// This class can be used to write an infinite stream of Tensors to the -/// database in a fixed block of contiguous disk space. This is -/// accomplished using Algorithm R reservoir sampling. -/// -/// The reservoir consists of a fixed number of rows, which are inserted -/// using ZEROBLOB upon receiving the first sample, which is used to -/// predict how big the other ones are likely to be. This is done -/// transactionally in a way that tries to be mindful of other processes -/// that might be trying to access the same DB. -/// -/// Once the reservoir fills up, rows are replaced at random, and writes -/// gradually become no-ops. This allows long training to go fast -/// without configuration. The exception is when someone is actually -/// looking at TensorBoard. 
When that happens, the "keep last" behavior -/// is turned on and Append() will always result in a write. -/// -/// If no one is watching training, this class still holds on to the -/// most recent "dangling" Tensor, so if Finish() is called, the most -/// recent training state can be written to disk. -/// -/// The randomly selected sampling points should be consistent across -/// multiple instances. -/// /// This class is thread safe. class SeriesWriter { public: - SeriesWriter(int64 series, int slots, RunMetadata* meta) - : series_{series}, - slots_{slots}, - meta_{meta}, - rng_{std::mt19937_64::default_seed} { + SeriesWriter(int64 series, RunMetadata* meta) : series_{series}, meta_{meta} { DCHECK(series_ > 0); - DCHECK(slots_ > 0); } Status Append(Sqlite* db, int64 step, uint64 now, double computed_time, - Tensor t) SQLITE_TRANSACTIONS_EXCLUDED(*db) + const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { mutex_lock lock(mu_); if (rowids_.empty()) { @@ -705,41 +640,20 @@ class SeriesWriter { return s; } } - DCHECK(rowids_.size() == slots_); - int64 rowid; - size_t i = count_; - if (i < slots_) { - rowid = last_rowid_ = rowids_[i]; - } else { - i = rng_() % (i + 1); - if (i < slots_) { - rowid = last_rowid_ = rowids_[i]; - } else { - bool keep_last; - TF_RETURN_IF_ERROR(meta_->GetIsWatching(db, &keep_last)); - if (!keep_last) { - ++count_; - dangling_tensor_.reset(new Tensor(std::move(t))); - dangling_step_ = step; - dangling_computed_time_ = computed_time; - return Status::OK(); - } - rowid = last_rowid_; - } - } + int64 rowid = rowids_.front(); Status s = Write(db, rowid, step, computed_time, t); if (s.ok()) { ++count_; - dangling_tensor_.reset(); } + rowids_.pop_front(); return s; } Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { mutex_lock lock(mu_); - // Short runs: Delete unused pre-allocated Tensors. - if (count_ < rowids_.size()) { + // Delete unused pre-allocated Tensors. 
+ if (!rowids_.empty()) { SqliteTransaction txn(*db); const char* sql = R"sql( DELETE FROM Tensors WHERE rowid = ? @@ -747,19 +661,13 @@ class SeriesWriter { SqliteStatement deleter; TF_RETURN_IF_ERROR(db->Prepare(sql, &deleter)); for (size_t i = count_; i < rowids_.size(); ++i) { - deleter.BindInt(1, rowids_[i]); + deleter.BindInt(1, rowids_.front()); TF_RETURN_IF_ERROR(deleter.StepAndReset()); + rowids_.pop_front(); } TF_RETURN_IF_ERROR(txn.Commit()); rowids_.clear(); } - // Long runs: Make last sample be the very most recent one. - if (dangling_tensor_) { - DCHECK(last_rowid_ != kAbsent); - TF_RETURN_IF_ERROR(Write(db, last_rowid_, dangling_step_, - dangling_computed_time_, *dangling_tensor_)); - dangling_tensor_.reset(); - } return Status::OK(); } @@ -783,7 +691,6 @@ class SeriesWriter { Status Update(Sqlite* db, int64 step, double computed_time, const Tensor& t, const StringPiece& data, int64 rowid) { - // TODO(jart): How can we ensure reservoir fills on replace? const char* sql = R"sql( UPDATE OR REPLACE Tensors @@ -878,7 +785,7 @@ class SeriesWriter { // TODO(jart): Maybe preallocate index pages by setting step. This // is tricky because UPDATE OR REPLACE can have a side // effect of deleting preallocated rows. 
- for (int64 i = 0; i < slots_; ++i) { + for (int64 i = 0; i < kPreallocateRows; ++i) { insert.BindInt(1, series_); insert.BindInt(2, reserved_bytes); TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), "i=", i); @@ -902,16 +809,10 @@ class SeriesWriter { mutex mu_; const int64 series_; - const int slots_; RunMetadata* const meta_; - std::mt19937_64 rng_ GUARDED_BY(mu_); uint64 count_ GUARDED_BY(mu_) = 0; - int64 last_rowid_ GUARDED_BY(mu_) = kAbsent; - std::vector rowids_ GUARDED_BY(mu_); + std::deque rowids_ GUARDED_BY(mu_); uint64 unflushed_bytes_ GUARDED_BY(mu_) = 0; - std::unique_ptr dangling_tensor_ GUARDED_BY(mu_); - int64 dangling_step_ GUARDED_BY(mu_) = 0; - double dangling_computed_time_ GUARDED_BY(mu_) = 0.0; TF_DISALLOW_COPY_AND_ASSIGN(SeriesWriter); }; @@ -928,10 +829,10 @@ class RunWriter { explicit RunWriter(RunMetadata* meta) : meta_{meta} {} Status Append(Sqlite* db, int64 tag_id, int64 step, uint64 now, - double computed_time, Tensor t, int slots) + double computed_time, const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { - SeriesWriter* writer = GetSeriesWriter(tag_id, slots); - return writer->Append(db, step, now, computed_time, std::move(t)); + SeriesWriter* writer = GetSeriesWriter(tag_id); + return writer->Append(db, step, now, computed_time, t); } Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db) @@ -948,11 +849,11 @@ class RunWriter { } private: - SeriesWriter* GetSeriesWriter(int64 tag_id, int slots) LOCKS_EXCLUDED(mu_) { + SeriesWriter* GetSeriesWriter(int64 tag_id) LOCKS_EXCLUDED(mu_) { mutex_lock sl(mu_); auto spot = series_writers_.find(tag_id); if (spot == series_writers_.end()) { - SeriesWriter* writer = new SeriesWriter(tag_id, slots, meta_); + SeriesWriter* writer = new SeriesWriter(tag_id, meta_); series_writers_[tag_id].reset(writer); return writer; } else { @@ -1082,8 +983,7 @@ class SummaryDbWriter : public SummaryWriterInterface { TF_RETURN_IF_ERROR( meta_.GetTagId(db_, now, computed_time, 
tag, &tag_id, metadata)); TF_RETURN_WITH_CONTEXT_IF_ERROR( - run_.Append(db_, tag_id, step, now, computed_time, t, - GetSlots(t, metadata)), + run_.Append(db_, tag_id, step, now, computed_time, t), meta_.user_name(), "/", meta_.experiment_name(), "/", meta_.run_name(), "/", tag, "@", step); return Status::OK(); @@ -1155,8 +1055,7 @@ class SummaryDbWriter : public SummaryWriterInterface { int64 tag_id; TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), &tag_id, s->metadata())); - return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t, - GetSlots(t, s->metadata())); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } // TODO(jart): Refactor Summary -> Tensor logic into separate file. @@ -1169,8 +1068,7 @@ class SummaryDbWriter : public SummaryWriterInterface { PatchPluginName(s->mutable_metadata(), kScalarPluginName); TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), &tag_id, s->metadata())); - return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), - std::move(t), kScalarSlots); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } Status MigrateHistogram(const Event* e, Summary::Value* s, uint64 now) { @@ -1201,8 +1099,7 @@ class SummaryDbWriter : public SummaryWriterInterface { PatchPluginName(s->mutable_metadata(), kHistogramPluginName); TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), &tag_id, s->metadata())); - return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), - std::move(t), kHistogramSlots); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } Status MigrateImage(const Event* e, Summary::Value* s, uint64 now) { @@ -1216,8 +1113,7 @@ class SummaryDbWriter : public SummaryWriterInterface { PatchPluginName(s->mutable_metadata(), kImagePluginName); TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), &tag_id, s->metadata())); - return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), - 
std::move(t), kImageSlots); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } Status MigrateAudio(const Event* e, Summary::Value* s, uint64 now) { @@ -1230,8 +1126,7 @@ class SummaryDbWriter : public SummaryWriterInterface { PatchPluginName(s->mutable_metadata(), kAudioPluginName); TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), &tag_id, s->metadata())); - return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), - std::move(t), kAudioSlots); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } Env* const env_; diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc index 2044692b6e746bc317843d715fa17ab5ec0bf99d..2e8d4109dd624ab66d774668ad04def9a7d3cdf2 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc @@ -189,7 +189,7 @@ TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) { ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Experiments")); ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs")); ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags")); - ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); + ASSERT_EQ(1000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); int64 user_id = QueryInt("SELECT user_id FROM Users"); int64 experiment_id = QueryInt("SELECT experiment_id FROM Experiments"); @@ -238,7 +238,7 @@ TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) { ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments")); ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs")); ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags")); - ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); + ASSERT_EQ(1000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); } TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) { @@ -255,7 +255,7 @@ TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) { 
TF_ASSERT_OK(writer_->WriteEvent(std::move(e))); TF_ASSERT_OK(writer_->Flush()); ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tags")); - ASSERT_EQ(20000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); + ASSERT_EQ(2000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); int64 tag1_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'π'"); int64 tag2_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'φ'"); EXPECT_GT(tag1_id, 0LL); diff --git a/tensorflow/contrib/tensorboard/graph_explorer/proto/graph_explorer.proto b/tensorflow/contrib/tensorboard/graph_explorer/proto/graph_explorer.proto deleted file mode 100644 index 835337ed5c58d0f0595ce8a88f08c8e63a860a36..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorboard/graph_explorer/proto/graph_explorer.proto +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2015 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the 'License'); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an 'AS IS' BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -// GraphExplorer is a tool that supports interactive, hierarchical visualization -// of graphs. GraphExplorer renders graphs generated by TensorFlow represented -// as GraphDef messages defined in tensorflow/core/framework/graph.proto. The -// GraphDef proto does not allow for explicitly specifying visual attributes of -// the graph such as color, line thickness, fonts, etc. 
This file introduces a -// new proto for representing graphs and specifying visual attributes of graphs. -// -// The structure of the Graph proto is given by the EBNF grammar below. Consult -// the message definitions below for details. -// -// graph ::= node* edge* node_attribute* metanode_attribute* edge_attribute* -// graph_attribute* -// node ::= node_id node_attribute* metanode_attribute* node_data* -// edge ::= source_id target_id edge_attribute* edge_data* -// -// A graph consists of a list of nodes and a list of edges and attributes for -// nodes, edges and the graph. Attributes have a name and a value and are -// represented as key-value pairs, with {"color", "blue"} being an example. -// Attributes have a scope, where the broadest scope is the graph and the -// narrowest is a node that has no internal structure. -syntax = "proto3"; - -package graph_explorer; - -// There are two types of nodes. A 'metanode' contains other -// nodes and a 'leaf node' has no internal structure. The metanode containment -// relationship is acyclic, meaning that if a metanode 'A' contains the metanode -// 'B', then 'B' cannot contain 'A'. -message Node { - // The identifier of a node is a sequence of strings separated by '/'. The - // identifier provides a unique name for a node and defines its hierarchical - // relation to other nodes. If no label is provided the last part of the - // identifier is used as a label. - // - // Example: In the graph below, metanodes are written with square brackets and - // leaf nodes with parentheses. The metanode 'node1' contains the leaf node - // 'node4' and the metanode 'node2', which contains the leaf node 'node3'. - // - // [node1 [node2 (node3)] (node4)] - // - // The identifiers for these nodes are: "node1", "node1/node2", - // "node1/node2/node3", and "node1/node4". - string name = 1; - - // A node attribute is information used by Graph Explorer to style a node. 
- map node_attr = 2; - - // A metanode attribute is one that is inherited by all nodes inside the - // current metanode. If an attribute applies only to the current node and - // should not be inherited, it should be specified as a node attribute. - map metanode_attr = 3; -}; - -// An edge consists of a source and a target node, specified by their -// identifiers. An edge has attributes and data that are similar to node -// attributes and node data. Edges do not form a hierarchy so there are no -// metanode attributes. -message Edge { - // The source and target fields must have the format of a Node name. - string source = 1; - string target = 2; - - // Edge attributes. - map edge_attr = 3; -} - -message Graph { - // List of nodes in the graph. - repeated Node node = 1; - - // List of edges in the graph. - repeated Edge edge = 2; - - // Default values of node, metanode and edge attributes. - map node_attr = 3; - map metanode_attr = 4; - map edge_attr = 5; - - // Graph attributes. - map graph_attr = 6; -}; diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 675f0b1fd6ede56e38da3bd9dad4ae61e11a9be8..a5d8b061b6b26f9d05be40a1162481ae219b0e9c 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -67,6 +67,7 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = [ ":trt_logging", + ":trt_plugins", ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]) + tf_custom_op_library_additional_deps(), @@ -86,6 +87,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":trt_logging", + ":trt_plugins", ":trt_resources", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib_proto_parsing", @@ -190,7 +192,7 @@ tf_py_wrap_cc( ":trt_conversion", ":trt_engine_op_kernel", "//tensorflow/core:framework_lite", - "//util/python:python_headers", + "//third_party/python_runtime:headers", ], ) @@ -232,6 +234,7 @@ tf_cuda_library( ], deps = [ ":segment", + ":trt_plugins", ":trt_logging", 
":trt_resources", "//tensorflow/core/grappler/clusters:cluster", @@ -263,7 +266,6 @@ cc_library( "segment/segment.h", "segment/union_find.h", ], - linkstatic = 1, deps = [ "//tensorflow/core:graph", "//tensorflow/core:lib_proto_parsing", @@ -286,6 +288,46 @@ tf_cc_test( ], ) +# Library for the plugin factory +tf_cuda_library( + name = "trt_plugins", + srcs = [ + "plugin/trt_plugin.cc", + "plugin/trt_plugin_factory.cc", + "plugin/trt_plugin_utils.cc", + ], + hdrs = [ + "plugin/trt_plugin.h", + "plugin/trt_plugin_factory.h", + "plugin/trt_plugin_utils.h", + ], + deps = [ + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + +tf_cuda_cc_test( + name = "trt_plugin_factory_test", + size = "small", + srcs = ["plugin/trt_plugin_factory_test.cc"], + tags = [ + "manual", + "notap", + ], + deps = [ + ":trt_plugins", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ] + if_tensorrt([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_tensorrt//:nv_infer", + ]), +) + py_test( name = "tf_trt_integration_test", srcs = ["test/tf_trt_integration_test.py"], diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 4df54a749f5a2dc4884fb437a7a16cd3bb51fa17..da4dd5a14cd74591fc9df63cd5868044e4e369ec 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include #include @@ -77,7 +78,8 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) { // TODO(ben,jie): ... 
}; // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h) - return candidate_ops.count(node->type_string()); + return (candidate_ops.count(node->type_string()) || + PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); } void GetSubGraphIncomingEdges(const tensorflow::Graph& graph, @@ -89,8 +91,11 @@ void GetSubGraphIncomingEdges(const tensorflow::Graph& graph, if (!subgraph_node_ids.count(edge->src()->id()) && !edge->src()->IsSource() && !edge->IsControlEdge()) { incoming_edges->insert(edge); + VLOG(2) << "INCOMING " << edge->src()->name() << " -> " << node->name() + << " Y, "; } else { - VLOG(2) << node->name() << " -> " << edge->src()->name() << " N, "; + VLOG(2) << "INCOMING " << edge->src()->name() << " -> " << node->name() + << " N, "; } } } @@ -104,10 +109,12 @@ void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph, for (const tensorflow::Edge* edge : node->out_edges()) { if (!subgraph_node_ids.count(edge->dst()->id()) && !edge->dst()->IsSink() && !edge->IsControlEdge()) { - VLOG(2) << node->name() << " -> " << edge->dst()->name() << " Y, "; + VLOG(2) << "OUTGOING " << node->name() << " -> " << edge->dst()->name() + << " Y, "; outgoing_edges->insert(edge); } else { - VLOG(2) << node->name() << " -> " << edge->dst()->name() << " N, "; + VLOG(2) << "OUTGOING " << node->name() << " -> " << edge->dst()->name() + << " N, "; } } } @@ -179,29 +186,27 @@ struct ConvertGraphParams { static tensorflow::Status FillSubGraphEdgeSets(ConvertGraphParams* p) { GetSubGraphIncomingEdges(p->graph, p->subgraph_node_ids, &p->subgraph_incoming_edges); + + std::set> unique_tensors; + // Add only unique input source nodes. 
If output of an outside node is shared + // between multiple nodes inside the engine, only one edge should be created for (const tensorflow::Edge* edge : p->subgraph_incoming_edges) { - p->subgraph_inputs.push_back({edge->src()->id(), edge->src_output()}); - } - auto output_name_to_index_map = BuildTensorNameMap(p->output_names); - std::set> subgraph_outputs_set; - // Collect outputs referenced from output_names - for (int node_id : p->subgraph_node_ids) { - tensorflow::Node* node = p->graph.FindNodeId(node_id); - if (output_name_to_index_map.count(node->name())) { - for (int index : output_name_to_index_map.at(node->name())) { - subgraph_outputs_set.insert({node_id, index}); - } - } + unique_tensors.insert({edge->src()->id(), edge->src_output()}); } + p->subgraph_inputs.insert(p->subgraph_inputs.begin(), unique_tensors.begin(), + unique_tensors.end()); GetSubGraphOutgoingEdges(p->graph, p->subgraph_node_ids, &p->subgraph_outgoing_edges); + unique_tensors.clear(); + // Similar to above, if multiple ouside nodes are sharing the output of an + // internal node only one output port should be created and shared between + // outputs for (const tensorflow::Edge* edge : p->subgraph_outgoing_edges) { - subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()}); + unique_tensors.insert({edge->src()->id(), edge->src_output()}); } - p->subgraph_outputs.reserve(subgraph_outputs_set.size()); + p->subgraph_outputs.reserve(unique_tensors.size()); p->subgraph_outputs.insert(p->subgraph_outputs.begin(), - subgraph_outputs_set.begin(), - subgraph_outputs_set.end()); + unique_tensors.begin(), unique_tensors.end()); return tensorflow::Status::OK(); } @@ -223,7 +228,6 @@ tensorflow::Status GetCalibNode(ConvertGraphParams* params) { for (auto in_edge : params->subgraph_incoming_edges) { // loop over incoming edges and // attach them to calib node - // tensorflow::Node* src_node = in_edge->src(); auto src_output = in_edge->src_output(); auto dst_node = in_edge->dst(); auto 
dst_input = in_edge->dst_input(); @@ -255,19 +259,24 @@ tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) { for (size_t i = 0; i < params->subgraph_inputs.size(); ++i) { subgraph_edge_to_input_map.insert({params->subgraph_inputs.at(i), i}); } + std::set> unique_tensors; for (const tensorflow::Edge* edge : params->subgraph_incoming_edges) { std::pair old_src = {edge->src()->id(), edge->src_output()}; + if (unique_tensors.count(old_src)) continue; + unique_tensors.insert(old_src); int new_src_output = subgraph_edge_to_input_map.at(old_src); params->graph.AddEdge(edge->src(), edge->src_output(), trt_node, new_src_output); + VLOG(1) << "Wire " << edge->src()->name() << ":" << edge->src_output() + << " -> " << trt_node->name() << ":" << new_src_output; params->graph.RemoveEdge(edge); } - - VLOG(2) << "new wiring edges: " << trt_node->in_edges().size(); - for (const tensorflow::Edge* edge : trt_node->in_edges()) { - VLOG(2) << edge->src()->name() << " port: " << edge->src_output(); + if (VLOG_IS_ON(2)) { + VLOG(2) << "new edge count: " << trt_node->in_edges().size(); + for (const tensorflow::Edge* edge : trt_node->in_edges()) { + VLOG(2) << edge->src()->name() << " port: " << edge->src_output(); + } } - TF_RETURN_IF_ERROR(status); // Re-map outgoing edges to use the new TRT node instead of the orig subgraph @@ -281,6 +290,8 @@ tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) { int new_src_output = subgraph_edge_to_output_map.at(old_src); TF_RETURN_IF_ERROR(params->graph.UpdateEdge( trt_node, new_src_output, edge->dst(), edge->dst_input())); + VLOG(1) << "Wire " << trt_node->name() << ":" << new_src_output << " -> " + << edge->dst()->name() << ":" << edge->dst_input(); } // Remove the original subgraph for (int node_id : params->subgraph_node_ids) { @@ -315,9 +326,12 @@ tensorflow::Status ConvertCalibGraphToInferGraph( tensorflow::GraphConstructorOptions(), graph_def, &graph)); // get calib nodes std::vector calib_nodes; - 
for (auto node : graph.op_nodes()) { + std::vector topo_order; + tensorflow::GetPostOrder(graph, &topo_order); + for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) { + auto node = *rit; if (node->type_string() == "TRTCalibOp") { - VLOG(1) << "Found Calib Node"; + VLOG(1) << "Found Calib Node " << node->name(); calib_nodes.push_back(node); } } diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index be559d30e0041736f2b16aa6b72d9ffbd363b518..4e4d295538edadd26a347a38ec141737f097f26f 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include #include @@ -240,35 +241,49 @@ class TFAttrs { return attrs_.at(key); } template - T get(string key) const; + T get(const string& key) const; template - T get(string key, const T& default_value) const { + T get(const string& key, const T& default_value) const { return attrs_.count(key) ? 
this->get(key) : default_value; } + std::vector GetAllAttrKey() { + std::vector attr_list; + for (const auto& attr_item : attrs_) { + attr_list.emplace_back(attr_item.first); + } + return attr_list; + } + private: typedef std::map AttrMap; AttrMap attrs_; }; template <> -string TFAttrs::get(string key) const { +string TFAttrs::get(const string& key) const { return this->at(key)->s(); } template <> -std::vector TFAttrs::get>(string key) const { +std::vector TFAttrs::get>(const string& key) const { auto attr = this->at(key)->list().i(); return std::vector(attr.begin(), attr.end()); } template <> -std::vector TFAttrs::get>(string key) const { +std::vector TFAttrs::get>(const string& key) const { + auto attr = this->at(key)->list().f(); + return std::vector(attr.begin(), attr.end()); +} + +template <> +std::vector TFAttrs::get>(const string& key) const { auto attr = this->at(key)->list().s(); return std::vector(attr.begin(), attr.end()); } template <> -nvinfer1::Dims TFAttrs::get(string key) const { +nvinfer1::Dims TFAttrs::get(const string& key) const { auto values = this->get>(key); nvinfer1::Dims dims; dims.nbDims = values.size(); @@ -278,24 +293,25 @@ nvinfer1::Dims TFAttrs::get(string key) const { } template <> -nvinfer1::DataType TFAttrs::get(string key) const { +nvinfer1::DataType TFAttrs::get(const string& key) const { nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT); TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype)); return trt_dtype; } template <> -tensorflow::DataType TFAttrs::get(string key) const { +tensorflow::DataType TFAttrs::get( + const string& key) const { return this->at(key)->type(); } template <> -float TFAttrs::get(string key) const { +float TFAttrs::get(const string& key) const { return this->at(key)->f(); } template <> -bool TFAttrs::get(string key) const { +bool TFAttrs::get(const string& key) const { return this->at(key)->b(); } @@ -346,10 +362,11 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights, break; } case 
tensorflow::DataType::DT_HALF: { - Reorder2({k, c}, static_cast(iweights.GetValues()), - istrides, static_cast( - const_cast(oweights->GetValues())), - ostrides); + Reorder2( + {k, c}, static_cast(iweights.GetValues()), + istrides, + static_cast(const_cast(oweights->GetValues())), + ostrides); break; } default: @@ -424,6 +441,7 @@ using OpConverter = class Converter { std::unordered_map trt_tensors_; std::unordered_map op_registry_; + OpConverter plugin_converter_; nvinfer1::INetworkDefinition* trt_network_; std::list> temp_bufs_; tensorflow::tensorrt::TRTWeightStore* weight_store_; @@ -490,13 +508,17 @@ class Converter { std::vector inputs; TF_RETURN_IF_ERROR(this->get_inputs(node_def, &inputs)); string op = node_def.op(); - if (!op_registry_.count(op)) { - return tensorflow::errors::Unimplemented( - "No converter registered for op: " + op); - } - OpConverter op_converter = op_registry_.at(op); std::vector outputs; - TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs)); + if (PluginFactoryTensorRT::GetInstance()->IsPlugin(op)) { + TF_RETURN_IF_ERROR(plugin_converter_(*this, node_def, inputs, &outputs)); + } else { + if (!op_registry_.count(op)) { + return tensorflow::errors::Unimplemented( + "No converter registered for op: " + op); + } + OpConverter op_converter = op_registry_.at(op); + TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs)); + } for (size_t i = 0; i < outputs.size(); ++i) { TRT_TensorOrWeights output = outputs.at(i); // TODO(jie): tf protobuf seems to be omitting the :0 suffix @@ -1158,9 +1180,9 @@ tensorflow::Status BinaryTensorOpTensor( CHECK_EQ_TYPE(tensor_r->getType(), dtype); auto op_pair = ops.find(node_def.op()); if (op_pair == ops.end()) - return tensorflow::errors::Unimplemented("binary op: " + node_def.op() + - " not supported at: " + - node_def.name()); + return tensorflow::errors::Unimplemented( + "binary op: " + node_def.op() + + " not supported at: " + node_def.name()); nvinfer1::IElementWiseLayer* layer 
= ctx.network()->addElementWise( *const_cast(tensor_l), @@ -1173,6 +1195,45 @@ tensorflow::Status BinaryTensorOpTensor( return tensorflow::Status::OK(); } +tensorflow::Status ConvertPlugin(Converter& ctx, + const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { + // prepare input + std::vector all_inputs; + for (auto input : inputs) { + all_inputs.emplace_back(const_cast(input.tensor())); + } + + // plugin is owned by PluginFactory + // TODO(jie): destroy plugins later (resource management) + PluginTensorRT* plugin = + PluginFactoryTensorRT::GetInstance()->CreatePlugin(node_def.op()); + + // passing attributes + // TODO(jie): support more general attribute + TFAttrs attrs(node_def); + auto attr_key_vector = attrs.GetAllAttrKey(); + for (auto attr_key : attr_key_vector) { + // TODO(jie): support only list of float for toy example here. + auto data = attrs.get>(attr_key); + size_t size_data = data.size() * sizeof(float); + if (!plugin->SetAttribute(attr_key, static_cast(data.data()), + size_data)) { + return tensorflow::errors::InvalidArgument("plugin SetAttribute failed"); + } + } + + nvinfer1::IPluginLayer* layer = ctx.network()->addPlugin( + &all_inputs[0], static_cast(inputs.size()), *plugin); + + for (int i = 0; i < layer->getNbOutputs(); i++) { + nvinfer1::ITensor* output_tensor = layer->getOutput(i); + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + } + return tensorflow::Status::OK(); +} + tensorflow::Status ConvertPlaceholder( Converter& ctx, const tensorflow::NodeDef& node_def, const std::vector& inputs, @@ -2073,12 +2134,12 @@ void Converter::register_op_converters() { op_registry_["Reshape"] = ConvertReshape; op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm; op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm; + + plugin_converter_ = ConvertPlugin; } } // namespace -tensorflow::Status GetTensorRTGraph(tensorrt::convert::SubGraphParams& s) { - return tensorflow::errors::Unimplemented("Not 
implemented yet"); -} + tensorflow::Status ConvertCalibrationNodeToEngineNode( tensorflow::Graph& graph, tensorflow::Node* c_node) { const auto ndef = c_node->def(); @@ -2102,9 +2163,23 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode( for (auto n : graph.op_nodes()) { node_maps.insert({n->name(), n}); } + std::set subgraph_ids; + for (const auto internal_node : segment_nodes) { + subgraph_ids.insert(node_maps.at(internal_node)->id()); + } + if (VLOG_IS_ON(2)) { + string node_names = StrCat(c_node->name(), " segment nodes= "); + + for (const auto& node_name : segment_nodes) { + StrAppend(&node_names, node_name, ", "); + } + VLOG(2) << node_names; + } + VLOG(1) << "Output Nodes:"; std::vector out_types; std::vector out_edges; + for (auto& i : output_nodes) { auto node_port = tensorflow::str_util::Split(i, ":"); VLOG(1) << " " << i << " in graph " << node_maps.count(i); @@ -2124,18 +2199,24 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode( out_types.push_back(out_node->output_type(0)); } for (auto out_edge : out_node->out_edges()) { + if (subgraph_ids.count(out_edge->dst()->id())) + continue; // skip internal edges; if (out_edge->src_output() == port) { out_edges.push_back(out_edge); - break; + VLOG(1) << "OUTPUT EDGE " << out_edge->src()->name() << ":" + << out_edge->src_output() << " -> " << out_edge->dst()->name() + << ":" << out_edge->dst_input(); } } } else { LOG(WARNING) << " couldn't find output node " << out_node_name; } } - VLOG(1) << "Input Nodes:"; - for (auto& i : input_names) { - VLOG(1) << " " << i << " in graph " << node_maps.count(i); + if (VLOG_IS_ON(1)) { + VLOG(1) << c_node->name() << " Input Nodes:"; + for (auto& i : input_names) { + VLOG(1) << " Input " << i << " in graph " << node_maps.count(i); + } } auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance(); auto resmgr = trt_rm->getManager("TRTCalibOps"); @@ -2169,14 +2250,24 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode( calib_res->builder_ = nullptr; 
tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp"); std::vector income_edges; + income_edges.resize(c_node->num_inputs()); for (const auto in_edge : c_node->in_edges()) { auto src = in_edge->src(); int dest_port = in_edge->dst_input(); - income_edges.emplace_back(src->name(), in_edge->src_output(), - c_node->input_type(dest_port)); + VLOG(1) << "Incoming connection " << src->name() << ":" + << in_edge->src_output() << " -> " << c_node->name() << ":" + << dest_port; + income_edges.at(dest_port) = {src->name(), in_edge->src_output(), + c_node->input_type(dest_port)}; } tensorflow::gtl::ArraySlice input_list( income_edges); + if (VLOG_IS_ON(2)) { + for (const auto& inp : input_list) { + VLOG(2) << " Input from inputlist " << inp.node << ":" << inp.index << " " + << tensorflow::DataTypeString(inp.data_type); + } + } op_builder.Input(input_list); tensorflow::NodeDef engine_node; const char* engine_plan_data = static_cast(engine_plan->data()); @@ -2193,13 +2284,26 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode( } auto trt_engine_node = graph.AddNode(engine_node, &status); TF_RETURN_IF_ERROR(status); - for (size_t i = 0; i < out_edges.size(); i++) { - VLOG(1) << "Connecting trt_engine_node output " << i << " with " - << out_edges.at(i)->dst()->name() << " port " - << out_edges.at(i)->dst_input(); - TF_RETURN_IF_ERROR(graph.UpdateEdge(trt_engine_node, i, - out_edges.at(i)->dst(), - out_edges.at(i)->dst_input())); + std::map port_map; + for (size_t t = 0; t < output_nodes.size(); t++) { + port_map.insert({output_nodes.at(t), t}); + } + for (auto& i : out_edges) { + string s(i->src()->name()); + if (i->src_output()) StrAppend(&s, ":", i->src_output()); + int out_port = port_map.at(s); + VLOG(1) << "Connecting " << trt_engine_node->name() << ":" << out_port + << " -> " << i->dst()->name() << ":" << i->dst_input(); + TF_RETURN_IF_ERROR( + graph.UpdateEdge(trt_engine_node, out_port, i->dst(), i->dst_input())); + } + for (const auto ed : 
trt_engine_node->in_edges()) { + VLOG(1) << "In Edge " << ed->src()->name() << ":" << ed->src_output() + << " -> " << ed->dst()->name() << ":" << ed->dst_input(); + } + for (const auto ed : trt_engine_node->out_edges()) { + VLOG(1) << "Out Edge " << ed->src()->name() << ":" << ed->src_output() + << " -> " << ed->dst()->name() << ":" << ed->dst_input(); } VLOG(1) << "Segment nodes:"; for (auto& i : segment_nodes) { @@ -2270,6 +2374,7 @@ tensorflow::Status ConvertSubgraph( std::vector* output_names, std::vector* output_dtypes, const string& engine_name) { + std::set added_tensors; for (const std::pair& input : s.input_inds) { VLOG(2) << "parsing input. Node id= " << input.first; int node_id = input.first; @@ -2312,7 +2417,6 @@ tensorflow::Status ConvertSubgraph( auto op_info = op_info_vec.at(shape_inference_output_idx); tensorflow::DataType tf_dtype = op_info.dtype(); - input_dtypes->push_back(tf_dtype); nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT); auto type_status = ConvertDType(tf_dtype, &dtype); @@ -2348,8 +2452,10 @@ tensorflow::Status ConvertSubgraph( if (output_idx != 0) { input_tensor_name = StrCat(node_name, ":", output_idx); } - + if (added_tensors.count(input_tensor_name)) continue; + added_tensors.insert(input_tensor_name); input_names->push_back(input_tensor_name); + input_dtypes->push_back(tf_dtype); nvinfer1::ITensor* input_tensor = converter.network()->addInput( input_tensor_name.c_str(), dtype, input_dim_pseudo_chw); @@ -2373,6 +2479,7 @@ tensorflow::Status ConvertSubgraph( // Gather output metadata int trt_engine_op_output_idx = 0; + added_tensors.clear(); for (const std::pair& output : s.output_inds) { int node_id = output.first; int output_idx = output.second; @@ -2389,6 +2496,8 @@ tensorflow::Status ConvertSubgraph( if (output_idx != 0) tensorflow::strings::StrAppend(&tensor_name, ":", output_idx); VLOG(2) << "Output tensor name: " << tensor_name; + if (added_tensors.count(tensor_name)) continue; + added_tensors.insert(tensor_name); 
output_names->push_back(tensor_name); auto tensor_or_weights = converter.get_tensor(tensor_name); if (!tensor_or_weights.is_tensor()) { @@ -2472,7 +2581,7 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) { // Build the TRT op // TODO(sami,ben,jie): proper naming! tensorflow::NodeDefBuilder op_builder(calib_op_name, "TRTCalibOp"); - SetInputList(s, &op_builder, &input_names, &input_dtypes); + TF_RETURN_IF_ERROR(SetInputList(s, &op_builder, &input_names, &input_dtypes)); std::vector segment_names; segment_names.reserve(s.subgraph_node_ids.size()); @@ -2570,7 +2679,7 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef( // Build the TRT op tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp"); - SetInputList(s, &op_builder, &input_names, &input_dtypes); + TF_RETURN_IF_ERROR(SetInputList(s, &op_builder, &input_names, &input_dtypes)); VLOG(0) << "Finished op preparation"; diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a89cf3ab8bfaecc74fc5890ccb7e7a7147278182 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -0,0 +1,118 @@ +# Description: +# Example for plugin support in TensorRT(http://developer.nvidia.com/tensorrt) +# through TensorFlow integration. Targeting TensorRT 3.0.4 +# APIs are meant to change while upgrading TRT. +# add init_py into pip package BUILD dependency to install it. 
+ +package(default_visibility = ["//tensorflow:__subpackages__"]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_custom_op_library", + "tf_custom_op_library_additional_deps", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_kernel_library", +) +load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load( + "@local_config_tensorrt//:build_defs.bzl", + "if_tensorrt", +) + +tf_gen_op_libs( + op_lib_names = ["inc_op"], +) + +tf_gen_op_wrapper_py( + name = "inc_op", + deps = [":inc_op_op_lib"], +) + +tf_custom_op_library( + name = "_inc_op.so", + srcs = [ + "inc_op_kernel.h", + "inc_op_plugin.cc", + "inc_op_plugin.h", + "ops/inc_op.cc", + ], + gpu_srcs = [ + "inc_op_kernel.h", + "inc_op_kernel.cu.cc", + ], + deps = [ + "//tensorflow/contrib/tensorrt:trt_plugins", + "//tensorflow/core:framework_lite", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + +tf_kernel_library( + name = "inc_op_plugin_kernel", + srcs = ["inc_op_plugin.cc"], + hdrs = [ + "inc_op_kernel.h", + "inc_op_plugin.h", + ], + gpu_srcs = [ + "inc_op_kernel.h", + "inc_op_kernel.cu.cc", + ], + deps = [ + "//tensorflow/contrib/tensorrt:trt_plugins", + "//tensorflow/core:stream_executor_headers_lib", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]) + tf_custom_op_library_additional_deps(), +) + +tf_custom_op_py_library( + name = "inc_op_loader", + srcs = ["inc_op.py"], + dso = [ + ":_inc_op.so", + ], + kernels = [ + ":inc_op_op_lib", + ":inc_op_plugin_kernel", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:resources", + ], +) + +py_library( + name = "init_py", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":inc_op", + ":inc_op_loader", + ], +) + +cuda_py_test( + name = "plugin_test", + size = "small", + srcs = ["plugin_test.py"], + additional_deps = [ + ":init_py", + 
"//tensorflow/contrib/util:util_py", + "//tensorflow/contrib/tensorrt:init_py", + "//tensorflow/python:platform", + "//tensorflow/python:client_testlib", + "//tensorflow/python:tf_optimizer", + ], + tags = [ + "manual", + "noguitar", + "notap", + ], +) diff --git a/tensorflow/python/keras/datasets/cifar100/__init__.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py similarity index 66% rename from tensorflow/python/keras/datasets/cifar100/__init__.py rename to tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py index ca93742673341660ba69712feb59c5dd32ea3252..363edab2e80ada5c5d52ae7ff66ff4af678b251d 100644 --- a/tensorflow/python/keras/datasets/cifar100/__init__.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,15 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ============================================================================== -"""CIFAR100 small image classification dataset.""" +# ============================================================================= +"""Import custom op for plugin and register it in plugin factory registry.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.cifar100 import load_data +from tensorflow.contrib.tensorrt.custom_plugin_examples import inc_op as import_inc_op_so +from tensorflow.contrib.tensorrt.custom_plugin_examples.ops import gen_inc_op -del absolute_import -del division -del print_function +inc_op = gen_inc_op.inc_plugin_trt diff --git a/tensorflow/python/keras/applications/nasnet/__init__.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py similarity index 64% rename from tensorflow/python/keras/applications/nasnet/__init__.py rename to tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py index 94eb145b85b85b2e52ca37e7aebc681c1f054e16..a007c3f54e208b7623db128f4069c2343d0283c8 100644 --- a/tensorflow/python/keras/applications/nasnet/__init__.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py @@ -11,18 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ============================================================================== -"""NASNet Keras applications.""" +# ============================================================================= +"""Loader for the custom inc_op.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.nasnet import decode_predictions -from tensorflow.python.keras._impl.keras.applications.nasnet import NASNetLarge -from tensorflow.python.keras._impl.keras.applications.nasnet import NASNetMobile -from tensorflow.python.keras._impl.keras.applications.nasnet import preprocess_input +import platform -del absolute_import -del division -del print_function +if platform.system() != "Windows": + # pylint: disable=g-import-not-at-top + from tensorflow.contrib.util import loader + from tensorflow.python.platform import resource_loader + # pylint: enable=g-import-not-at-top + + _inc_op = loader.load_op_library( + resource_loader.get_path_to_datafile("_inc_op.so")) +else: + raise RuntimeError("Windows not supported") diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..988b35f74f3989481f59c52c6320623a26704327 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" + +#include <vector> + +#include "tensorflow/core/framework/op_kernel.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "cuda/include/cuda_runtime_api.h" +#include "tensorflow/core/platform/stream_executor.h" + +namespace tensorflow { +namespace tensorrt { + +__global__ void VecInc(const float* vec, float inc, float* dest, int n) { + int i = blockDim.x * blockIdx.x + threadIdx.x; + if (i < n) dest[i] = vec[i] + inc; +} + +void IncrementKernel(const float* d_input, float inc, float* d_output, + int count, cudaStream_t stream) { + int threads_per_block = 256; + int blocks_per_grid = (count + threads_per_block - 1) / threads_per_block; + + VecInc<<<blocks_per_grid, threads_per_block, 0, stream>>>(d_input, inc, + d_output, count); +} + +// Note: this kernel definition is not needed in the plugin_test rule, but it is +// required for correctness of the TF program, i.e. if not using plugin or when +// run with trt optimization pass, the test should work. 
+class IncPluginTRT : public OpKernel { + public: + explicit IncPluginTRT(OpKernelConstruction* context) : OpKernel(context) { + std::vector<float> inc_list; + OP_REQUIRES_OK(context, context->GetAttr("inc", &inc_list)); + OP_REQUIRES(context, inc_list.size() == 1, + errors::InvalidArgument( + "The increment list should contain single element.")); + inc_ = inc_list[0]; + } + + void Compute(OpKernelContext* context) override { + const Tensor& input_tensor = context->input(0); + const TensorShape& input_shape = input_tensor.shape(); + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input_shape, &output_tensor)); + const cudaStream_t* stream = CHECK_NOTNULL( + reinterpret_cast<const cudaStream_t*>(context->op_device_context() + ->stream() + ->implementation() + ->CudaStreamMemberHack())); + IncrementKernel(input_tensor.flat<float>().data(), inc_, + output_tensor->flat<float>().data(), + input_shape.num_elements(), *stream); + } + + private: + float inc_; +}; + +REGISTER_KERNEL_BUILDER(Name("IncPluginTRT").Device(DEVICE_GPU), IncPluginTRT); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..c35955e105798b20f93f650624eac24f378beb0b --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_CUSTOM_PLUGIN_EXAMPLES_INC_OP_KERNEL_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_CUSTOM_PLUGIN_EXAMPLES_INC_OP_KERNEL_H_ + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "cuda/include/cuda_runtime_api.h" + +namespace tensorflow { +namespace tensorrt { + +void IncrementKernel(const float* d_input, float inc, float* d_output, + int count, cudaStream_t stream); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_CUSTOM_PLUGIN_EXAMPLES_INC_OP_KERNEL_H_ diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc new file mode 100644 index 0000000000000000000000000000000000000000..8d4c893af56689185da72398919e2241d451594b --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc @@ -0,0 +1,86 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" + +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +const char* kPluginName = "IncPluginTRT"; + +IncOpPlugin* CreateIncPlugin() { return new IncOpPlugin(); } + +IncOpPlugin* CreateIncPluginDeserialize(const void* buffer, size_t length) { + return new IncOpPlugin(buffer, length); +} + +REGISTER_TRT_PLUGIN(kPluginName, CreateIncPluginDeserialize, CreateIncPlugin); + +IncOpPlugin::IncOpPlugin() : plugin_name_(kPluginName) {} + +IncOpPlugin::IncOpPlugin(const void* serialized_data, size_t length) + : PluginTensorRT(serialized_data, length), plugin_name_(kPluginName) { + // account for the consumed pointer. 
+ size_t consumed_data = PluginTensorRT::getSerializationSize(); + assert(length - consumed_data >= sizeof(float)); + const char* buffer = reinterpret_cast<const char*>(serialized_data); + SetAttribute("inc", buffer + consumed_data, sizeof(float)); +} + +bool IncOpPlugin::SetAttribute(const string& key, const void* ptr, + const size_t size) { + if (strcmp(key.c_str(), "inc") == 0 && size == sizeof(float)) { + StoreAttribute(key, ptr, size); // save the attribute to own the data; + inc_ = *static_cast<const float*>(ptr); + return true; + } + return false; +} + +bool IncOpPlugin::GetAttribute(const string& key, const void** ptr, + size_t* size) const { + const auto& iter = attr_map_.find(key); + if (iter != attr_map_.end()) { + *ptr = iter->second.data(); + *size = iter->second.size(); + return true; + } + return false; +} + +int IncOpPlugin::enqueue(int batch_size, const void* const* inputs, + void** outputs, void*, cudaStream_t stream) { + int count = 1; + for (int i = 0; i < input_dim_list_[0].nbDims; i++) { + count *= input_dim_list_[0].d[i]; + } + count *= batch_size; + const float* input = reinterpret_cast<const float*>(inputs[0]); + float* output = reinterpret_cast<float*>(outputs[0]); + IncrementKernel(input, inc_, output, count, stream); + return 0; +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h new file mode 100644 index 0000000000000000000000000000000000000000..189e9c939b9ffd4450f7ba95fe1abdbbc049b430 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_CUSTOM_PLUGIN_EXAMPLES_INC_OP_PLUGIN_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_CUSTOM_PLUGIN_EXAMPLES_INC_OP_PLUGIN_H_ + +#include <cassert> +#include <cstring> + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +class IncOpPlugin : public PluginTensorRT { + public: + IncOpPlugin(); + + IncOpPlugin(const void* serialized_data, size_t length); + + const string& GetPluginName() const override { return plugin_name_; }; + + bool Finalize() override { return true; }; + + bool SetAttribute(const string& key, const void* ptr, + const size_t size) override; + + bool GetAttribute(const string& key, const void** ptr, + size_t* size) const override; + + int getNbOutputs() const override { return 1; } + + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, + int num_input_dims) override { + assert(index == 0); + assert(num_input_dims == 1); + return inputs[0]; + } + + // use configure to setup input dimensions + void configure(const nvinfer1::Dims* inputs, int num_inputs, + const nvinfer1::Dims* outputs, int num_outputs, + int max_batch_size) override { + assert(num_inputs == 1); + PluginTensorRT::configure(inputs, num_inputs, outputs, num_outputs, + max_batch_size); + } + + int initialize() override { return 0; } + + void terminate() override {} + + size_t getWorkspaceSize(int max_batch_size) const override { 
return 0; } + + 
int enqueue(int batch_size, const void* const* inputs, void** outputs, + void* workspace, cudaStream_t stream) override; + + size_t getSerializationSize() override { + return PluginTensorRT::getSerializationSize() + sizeof(float); + } + + void serialize(void* buffer) override { + // Serialize parent data. + PluginTensorRT::serialize(buffer); + // Incremented buffer after parent serialization. + buffer = + static_cast<char*>(buffer) + PluginTensorRT::getSerializationSize(); + std::memcpy(buffer, &inc_, sizeof(float)); + buffer = static_cast<char*>(buffer) + sizeof(float); + } + + protected: + float inc_; + nvinfer1::Dims dim_; + + private: + const string plugin_name_; +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_CUSTOM_PLUGIN_EXAMPLES_INC_OP_PLUGIN_H_ diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d0eb0d299dd61dcc5c889e61994e6430340cdb1d --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { + +REGISTER_OP("IncPluginTRT") + .Attr("inc: list(float)") + .Input("input: float32") + .Output("output: float32") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }); + +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py new file mode 100644 index 0000000000000000000000000000000000000000..bc4d270bec4fb83d8ea067fca3a750270755a659 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Script to show usage of TensorRT custom op & plugin.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy + +from tensorflow.contrib import tensorrt +from tensorflow.contrib.tensorrt import custom_plugin_examples +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + + +class TrtPluginTest(test_util.TensorFlowTestCase): + + def _get_plugin_graph_def(self): + """Create a simple graph and return its graph_def.""" + g = ops.Graph() + with g.as_default(): + a = array_ops.placeholder( + dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input") + relu = nn.relu(a, "relu") + v = nn_ops.max_pool( + relu, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") + + # insert custom_op in the graph + v = custom_plugin_examples.inc_op(v, inc=[16.5], name="plugin_test") + + v *= 2.0 + v = nn.relu(v) + v = nn.relu(v) + array_ops.squeeze(v, name="output") + return g.as_graph_def() + + def _run_graph(self, gdef, dumm_inp): + """Run given graphdef once.""" + gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + ops.reset_default_graph() + g = ops.Graph() + with g.as_default(): + inp, out = importer.import_graph_def( + graph_def=gdef, return_elements=["input", "output"]) + inp = inp.outputs[0] + out = out.outputs[0] + + with session.Session( + config=config_pb2.ConfigProto(gpu_options=gpu_options), + graph=g) as sess: + val = sess.run(out, {inp: dumm_inp}) + return val + + def testIncOpPlugin(self): 
+ inp_dims = (5, 24, 24, 2) + dummy_input = numpy.ones(inp_dims).astype(numpy.float32) + orig_graph = self._get_plugin_graph_def() # graph with plugin node + + # trigger conversion. + # plugin nodes have been registered during import, converter will be able to + # create corresponding plugin layer during conversion. + trt_graph = tensorrt.create_inference_graph( + input_graph_def=orig_graph, + outputs=["output"], + max_batch_size=inp_dims[0], + max_workspace_size_bytes=1 << 25, + precision_mode="FP32", + minimum_segment_size=2) + o2 = self._run_graph(trt_graph, dummy_input) + self.assertEqual(35, o2.reshape([-1])[0]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 5c5b2e3c073d5fc38a8505e9f2c7ddf117cb4ffd..9ac8047944874181de228a6cc58e2dafe46abe50 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/types.h" @@ -59,7 +60,8 @@ void TRTEngineOp::Compute(OpKernelContext* context) { infer->setGpuAllocator(allocator_.get()); #endif trt_engine_ptr_.reset(infer->deserializeCudaEngine( - serialized_engine_.c_str(), serialized_engine_.size(), nullptr)); + serialized_engine_.c_str(), serialized_engine_.size(), + PluginFactoryTensorRT::GetInstance())); trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext()); // Runtime is safe to delete after engine creation infer->destroy(); diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/contrib/tensorrt/log/trt_logger.h index 7f3544f8cfda8dce13881e1f8f4388b640e315f4..96ccacb791e40143c5c4d9d691bb353702f9a28b 100644 --- a/tensorflow/contrib/tensorrt/log/trt_logger.h +++ b/tensorflow/contrib/tensorrt/log/trt_logger.h @@ -28,7 +28,7 @@ namespace tensorrt { // Logger for GIE info/warning/errors class Logger : public nvinfer1::ILogger { public: - Logger(string name = "DefaultLogger") : name_(name){}; + Logger(string name = "DefaultLogger") : name_(name) {} void log(nvinfer1::ILogger::Severity severity, const char* msg) override; private: diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc new file mode 100644 index 0000000000000000000000000000000000000000..062f86e8bb4dc753925e4e2baf0bc80a5312a94f --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc @@ -0,0 +1,106 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include +#include +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +PluginTensorRT::PluginTensorRT(const void* serialized_data, size_t length) { + const char* buffer = static_cast<const char*>(serialized_data); + size_t op_name_char_count = *reinterpret_cast<const size_t*>(buffer); + buffer += sizeof(size_t); + buffer += op_name_char_count; + + size_t count = *reinterpret_cast<const size_t*>(buffer); + buffer += sizeof(size_t); + + for (int i = 0; i < count; i++) { + nvinfer1::Dims dim; + std::memcpy(&(dim.nbDims), buffer, sizeof(dim.nbDims)); + buffer += sizeof(dim.nbDims); + std::memcpy(dim.d, buffer, sizeof(dim.d)); + buffer += sizeof(dim.d); + std::memcpy(dim.type, buffer, sizeof(dim.type)); + buffer += sizeof(dim.type); + input_dim_list_.emplace_back(dim); + } +} + +void PluginTensorRT::configure(const nvinfer1::Dims* inputs, int num_inputs, + const nvinfer1::Dims* outputs, int num_outputs, + int max_batch_size) { + for (int index = 0; index < num_inputs; index++) { + nvinfer1::Dims dim; + dim.nbDims = inputs[index].nbDims; + for (int i = 0; i < dim.nbDims; i++) { + dim.d[i] = inputs[index].d[i]; + dim.type[i] = inputs[index].type[i]; + } + input_dim_list_.emplace_back(dim); + } +} + +size_t PluginTensorRT::getSerializationSize() { + nvinfer1::Dims dim; + return sizeof(size_t) + GetPluginName().size() + + sizeof(input_dim_list_.size()) + sizeof(dim.nbDims) +
sizeof(dim.d) + + sizeof(dim.type); +} + +void PluginTensorRT::serialize(void* serialized_data) { + size_t op_name_size = GetPluginName().size(); + char* buffer = static_cast<char*>(serialized_data); + std::memcpy(buffer, &op_name_size, sizeof(size_t)); + buffer += sizeof(size_t); + + std::memcpy(buffer, GetPluginName().data(), op_name_size); + buffer += op_name_size; + + auto list_size = input_dim_list_.size(); + std::memcpy(buffer, &list_size, sizeof(input_dim_list_.size())); + buffer += sizeof(input_dim_list_.size()); + + for (int i = 0; i < input_dim_list_.size(); i++) { + auto dim = input_dim_list_[i]; + std::memcpy(buffer, &(dim.nbDims), sizeof(dim.nbDims)); + buffer += sizeof(dim.nbDims); + std::memcpy(buffer, dim.d, sizeof(dim.d)); + buffer += sizeof(dim.d); + std::memcpy(buffer, dim.type, sizeof(dim.type)); + buffer += sizeof(dim.type); + } +} + +bool PluginTensorRT::StoreAttribute(const string& key, const void* ptr, + const size_t size) { + if (attr_map_.count(key) != 0) return false; + + attr_map_.emplace(key, std::vector<char>(size)); + std::memcpy(attr_map_[key].data(), ptr, size); + return true; +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h new file mode 100644 index 0000000000000000000000000000000000000000..754920b60ca7439513a91ad0354833a2482b29c1 --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h @@ -0,0 +1,74 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_ + +#include +#include +#include + +#include "tensorflow/core/platform/types.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +// A wrapper class for TensorRT plugin +// User application should inherit from this class to write custom kernels. +// Allows user to insert custom op in TensorRT engine +// To register plugin in converter, user should also register custom +// PluginDeserializeFunc & PluginConstructFunc through PluginFactoryTensorRT +class PluginTensorRT : public nvinfer1::IPlugin { + public: + PluginTensorRT() {} + PluginTensorRT(const void* serialized_data, size_t length); + + virtual const string& GetPluginName() const = 0; + + virtual bool Finalize() = 0; + + virtual bool SetAttribute(const string& key, const void* ptr, + const size_t size) = 0; + virtual bool GetAttribute(const string& key, const void** ptr, + size_t* size) const = 0; + + void configure(const nvinfer1::Dims* inputs, int num_inputs, + const nvinfer1::Dims* outputs, int num_outputs, + int max_batch_size) override; + + virtual bool StoreAttribute(const string& key, const void* ptr, + const size_t size); + + size_t getSerializationSize() override; + + void serialize(void* buffer) override; + + protected: + std::unordered_map > attr_map_; + + std::vector input_dim_list_; +}; + +} // namespace tensorrt +} // namespace 
tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_ diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc new file mode 100644 index 0000000000000000000000000000000000000000..2bc591484dcaf5b35c39f3d0523dd89dcd152e6a --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -0,0 +1,78 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, + const void* serial_data, + size_t serial_length) { + size_t parsed_byte = 0; + // extract op_name from serial_data + string encoded_op_name = + ExtractOpName(serial_data, serial_length, &parsed_byte); + + if (!IsPlugin(encoded_op_name)) { + return nullptr; + } + + tensorflow::mutex_lock lock(instance_m_); + auto plugin_ptr = + plugin_registry_[encoded_op_name].first(serial_data, serial_length); + owned_plugins_.emplace_back(plugin_ptr); + + return plugin_ptr; +} + +PluginTensorRT* PluginFactoryTensorRT::CreatePlugin(const string& op_name) { + if (!IsPlugin(op_name)) return nullptr; + + tensorflow::mutex_lock lock(instance_m_); + auto plugin_ptr = plugin_registry_[op_name].second(); + owned_plugins_.emplace_back(plugin_ptr); + + return plugin_ptr; +} + +bool PluginFactoryTensorRT::RegisterPlugin( + const string& op_name, PluginDeserializeFunc deserialize_func, + PluginConstructFunc construct_func) { + if (IsPlugin(op_name)) return false; + + tensorflow::mutex_lock lock(instance_m_); + auto ret = plugin_registry_.emplace( + op_name, std::make_pair(deserialize_func, construct_func)); + + return ret.second; +} + +void PluginFactoryTensorRT::DestroyPlugins() { + tensorflow::mutex_lock lock(instance_m_); + for (auto& owned_plugin_ptr : owned_plugins_) { + owned_plugin_ptr.release(); + } + owned_plugins_.clear(); +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h new file mode 100644 index 0000000000000000000000000000000000000000..bbae9fb65c22cf69d2e7954436fd04dd16f7f6c8 --- 
/dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ + +#include +#include + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { + public: + // TODO(aaroey): this static method has to be inlined to make the singleton a + // unique global symbol. Find a way to fix it. + static PluginFactoryTensorRT* GetInstance() { + static PluginFactoryTensorRT* factory_instance = + new PluginFactoryTensorRT(); + return factory_instance; + } + + // Deserialization method + PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data, + size_t serial_length) override; + + // Plugin construction, PluginFactoryTensorRT owns the plugin. 
+ PluginTensorRT* CreatePlugin(const string& op_name); + + bool RegisterPlugin(const string& op_name, + PluginDeserializeFunc deserialize_func, + PluginConstructFunc construct_func); + + bool IsPlugin(const string& op_name) { + return plugin_registry_.find(op_name) != plugin_registry_.end(); + } + + size_t CountOwnedPlugins() { return owned_plugins_.size(); } + + void DestroyPlugins(); + + protected: + std::unordered_map> + plugin_registry_; + + // TODO(jie): Owned plugin should be associated with different sessions; + // should really hand ownership of plugins to resource management; + std::vector> owned_plugins_; + tensorflow::mutex instance_m_; +}; + +class TrtPluginRegistrar { + public: + TrtPluginRegistrar(const string& name, PluginDeserializeFunc deserialize_func, + PluginConstructFunc construct_func) { + auto factory = PluginFactoryTensorRT::GetInstance(); + QCHECK(factory->RegisterPlugin(name, deserialize_func, construct_func)) + << "Failed to register plugin: " << name; + } +}; + +#define REGISTER_TRT_PLUGIN(name, deserialize_func, construct_func) \ + REGISTER_TRT_PLUGIN_UNIQ_HELPER(__COUNTER__, name, deserialize_func, \ + construct_func) +#define REGISTER_TRT_PLUGIN_UNIQ_HELPER(ctr, name, deserialize_func, \ + construct_func) \ + REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) +#define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) \ + static ::tensorflow::tensorrt::TrtPluginRegistrar trt_plugin_registrar##ctr \ + TF_ATTRIBUTE_UNUSED = ::tensorflow::tensorrt::TrtPluginRegistrar( \ + name, deserialize_func, construct_func) + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc new file mode 100644 index 
0000000000000000000000000000000000000000..129bdcdbc2f8d9d5215f45f381bcadf35e4fa75e --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc @@ -0,0 +1,125 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { +namespace test { + +class StubPlugin : public PluginTensorRT { + public: + static const char* kPluginName; + + StubPlugin() : plugin_name_(kPluginName) {} + + StubPlugin(const void* serialized_data, size_t length) + : PluginTensorRT(serialized_data, length) {} + + const string& GetPluginName() const override { return plugin_name_; } + + bool Finalize() override { return true; } + + bool SetAttribute(const string& key, const void* ptr, + const size_t size) override { + return true; + } + + bool GetAttribute(const string& key, const void** ptr, + size_t* size) const override { + return true; + } + + int getNbOutputs() const override { return 1; } + + nvinfer1::Dims getOutputDimensions(int 
index, const nvinfer1::Dims* inputs, + int nbInputDims) override { + return inputs[0]; + } + + int initialize() override { return 0; } + + void terminate() override {} + + size_t getWorkspaceSize(int maxBatchSize) const override { return 0; } + + int enqueue(int batch_size, const void* const* inputs, void** outputs, + void* workspace, cudaStream_t stream) override { + return 0; + } + + private: + const string plugin_name_; +}; + +const char* StubPlugin::kPluginName = "StubPlugin"; + +StubPlugin* CreateStubPlugin() { return new StubPlugin(); } + +StubPlugin* CreateStubPluginDeserialize(const void* serialized_data, + size_t length) { + return new StubPlugin(serialized_data, length); +} + +class TrtPluginFactoryTest : public ::testing::Test { + public: + bool RegisterStubPlugin() { + if (PluginFactoryTensorRT::GetInstance()->IsPlugin( + StubPlugin::kPluginName)) { + return true; + } + return PluginFactoryTensorRT::GetInstance()->RegisterPlugin( + StubPlugin::kPluginName, CreateStubPluginDeserialize, CreateStubPlugin); + } +}; + +TEST_F(TrtPluginFactoryTest, Registration) { + EXPECT_FALSE( + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); + EXPECT_TRUE(RegisterStubPlugin()); + + ASSERT_TRUE( + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); +} + +TEST_F(TrtPluginFactoryTest, CreationDeletion) { + EXPECT_TRUE(RegisterStubPlugin()); + ASSERT_TRUE( + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); + + PluginFactoryTensorRT::GetInstance()->DestroyPlugins(); + ASSERT_TRUE(PluginFactoryTensorRT::GetInstance()->CreatePlugin( + StubPlugin::kPluginName)); + ASSERT_EQ(1, PluginFactoryTensorRT::GetInstance()->CountOwnedPlugins()); + PluginFactoryTensorRT::GetInstance()->DestroyPlugins(); + ASSERT_EQ(0, PluginFactoryTensorRT::GetInstance()->CountOwnedPlugins()); +} + +} // namespace test +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git 
a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..a8f60886c03c174a612e7a135b6eb7bb7cb9997a --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" +#include + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +string ExtractOpName(const void* serial_data, size_t serial_length, + size_t* incremental) { + size_t op_name_char_count = *static_cast<const size_t*>(serial_data); + *incremental = sizeof(size_t) + op_name_char_count; + + assert(serial_length >= *incremental); + + const char* buffer = static_cast<const char*>(serial_data) + sizeof(size_t); + string op_name(buffer, op_name_char_count); + + return op_name; +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..274ce42fec9283c643004d45fba461879fc5f2dc --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h @@ -0,0 +1,46 @@ +/* Copyright 2018
The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ + +#include + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/core/platform/types.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +typedef std::function + PluginDeserializeFunc; + +typedef std::function PluginConstructFunc; + +// TODO(jie): work on error handling here +string ExtractOpName(const void* serial_data, size_t serial_length, + size_t* incremental); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc index 8038085a060dc92c3a046c7efe7d7a08ca97973a..f5b2d258d70d5577a9d68f2d9f6d6e678ede97ce 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc @@ -16,7 +16,6 @@ limitations under the License. 
 #include "tensorflow/contrib/tensorrt/segment/segment.h" #include "tensorflow/c/c_api.h" #include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/test.h" @@ -276,13 +275,13 @@ TEST_F(SegmentTest, Multiple) { // Expect two subgraphs EXPECT_EQ(segments.size(), 2); - std::vector<string> expected0{"add0", "add1", "add2", "add3"}; + std::vector<string> expected0{"add6", "add8"}; for (const auto& ex : expected0) { EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end()) << "Missing expected node " << ex; } - std::vector<string> expected1{"add6", "add8"}; + std::vector<string> expected1{"add0", "add1", "add2", "add3"}; for (const auto& ex : expected1) { EXPECT_TRUE(segments[1].first.find(ex) != segments[1].first.end()) << "Missing expected node " << ex; diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index 8b475177bc670ddae2b26b6a494f758eba20b2c3..f36495f6b69ecb2f2a8d730b9ae4919fea3c04b8 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ #include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include #include @@ -33,7 +34,8 @@ tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) { TF_RETURN_IF_ERROR(context->GetAttr("serialized_engine", &serialized_engine)); nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger); nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine( - serialized_engine.c_str(), serialized_engine.size(), nullptr); + serialized_engine.c_str(), serialized_engine.size(), + tensorrt::PluginFactoryTensorRT::GetInstance()); int num_batch = -1; std::vector<::tensorflow::DataType> input_type; diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py index a5c00dd6333183c87a41aa7effc32f814a2299e7..0403b652d72877196c3537a3181529aeeb997395 100644 --- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py +++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py @@ -34,7 +34,6 @@ from tensorflow.python.ops import nn_ops as nn_ops from tensorflow.python.platform import googletest -@test_util.with_c_api class IntegrationTest(test_util.TensorFlowTestCase): """Class to test Tensorflow-TensorRT integration.""" diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index d2746032a04946cdfab4b5ac968ea3add5f6b51d..e4963596d38dbe8aea98fddbc67dbbf761c215c8 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -110,6 +110,7 @@ py_test( "no_pip_gpu", # b/63391119 "nomsan", # Takes too long to run. "notsan", # b/67865658 + "optonly", # Takes too long to run without optimization. 
], deps = [ ":ar_model", diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py index ce96180c9271b95991826c2527cec526c1397ae5..d8089453340e894db6af9fc3a3b360c9512207eb 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py @@ -30,9 +30,9 @@ from tensorflow.python.estimator import estimator_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.keras._impl.keras.engine import sequential -from tensorflow.python.keras._impl.keras.engine import training -from tensorflow.python.keras._impl.keras.layers import core +from tensorflow.python.keras.engine import sequential +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import core from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py index 706742ca287a7d8a172472cd944c906e3eda335c..983455f63db07903a9b2996706c6dba731d5e2b8 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py @@ -68,15 +68,16 @@ class TimeSeriesRegressorTest(test.TestCase): eval_input_fn = input_pipeline.RandomWindowInputFn( input_pipeline.NumpyReader(features), shuffle_seed=3, num_threads=1, batch_size=16, window_size=16) - first_estimator.train(input_fn=train_input_fn, steps=5) + first_estimator.train(input_fn=train_input_fn, steps=1) first_loss_before_fit = first_estimator.evaluate( input_fn=eval_input_fn, steps=1)["loss"] - first_estimator.train(input_fn=train_input_fn, steps=50) + self.assertAllEqual([], 
first_loss_before_fit.shape) + first_estimator.train(input_fn=train_input_fn, steps=1) first_loss_after_fit = first_estimator.evaluate( input_fn=eval_input_fn, steps=1)["loss"] - self.assertLess(first_loss_after_fit, first_loss_before_fit) + self.assertAllEqual([], first_loss_after_fit.shape) second_estimator = estimator_fn(model_dir, exogenous_feature_columns) - second_estimator.train(input_fn=train_input_fn, steps=2) + second_estimator.train(input_fn=train_input_fn, steps=1) whole_dataset_input_fn = input_pipeline.WholeDatasetInputFn( input_pipeline.NumpyReader(features)) whole_dataset_evaluation = second_estimator.evaluate( diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc index defed00537c407216703b3bf8651d33cdf311b56..ab2a7a0d4bec48d6b3b459bb3144e8ddae614ca0 100644 --- a/tensorflow/contrib/tpu/ops/replication_ops.cc +++ b/tensorflow/contrib/tpu/ops/replication_ops.cc @@ -25,6 +25,7 @@ using shape_inference::ShapeHandle; REGISTER_OP("TPUReplicateMetadata") .Attr("num_replicas: int >= 0") .Attr("topology: string = \"\"") + .Attr("use_tpu: bool = true") .Attr("device_assignment: list(int) = []") .Attr("computation_shape: list(int) = []") .Attr("host_compute_core: list(string) = []") @@ -72,6 +73,7 @@ REGISTER_OP("TPUReplicate") .Attr("computation: func") .Attr("num_replicas: int >= 1") .Attr("topology: string = \"\"") + .Attr("use_tpu: bool = true") .Attr("device_assignment: list(int) = []") .Attr("host_compute_core: list(string) = []") .Attr("computation_shape: list(int) = []") @@ -93,6 +95,9 @@ computation: a function containing the computation to run. num_replicas: the number of replicas of the computation to run. topology: A serialized tensorflow.tpu.TopologyProto that describes the TPU topology. +use_tpu: a bool indicating if this computation will run on TPU or CPU/GPU. +Currently, only supports a default placement (computation is placed on GPU +if one is available, and on CPU if not). 
computation_shape: a [mesh_dimension] array describing the shape of each computation replica in numbers of cores in the TPU mesh. device_assignment: a flattened array with shape diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index dbf1ab6bbf0ddc7429d8e19279451eb862981e0c..3b2d7adfff6b8de3145a73756a8b5306445034c5 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -53,7 +53,7 @@ tf_cc_binary( "//tensorflow/core:lib", "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/platform/cloud:gcs_file_system", - "@grpc//:grpc++_unsecure", + "@grpc//:grpc++", ], ) diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index 816897499b7a49365060c026d60b977990f3ecdc..f80f5652af79d410946971573ae160fdd0b85f6d 100644 --- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -18,7 +18,7 @@ limitations under the License. // Initiates a TPU profiling on the TPUProfiler service at service_addr, // receives and dumps the profile data to a tensorboard log directory. 
-#include "grpc++/grpc++.h" +#include "grpcpp/grpcpp.h" #include #include @@ -79,7 +79,9 @@ ProfileRequest PopulateProfileRequest(int duration_ms, request.set_repository_root(repository_root); request.set_session_id(session_id); } + request.add_tools("op_profile"); request.add_tools("input_pipeline"); + request.add_tools("memory_viewer"); request.add_tools("overview_page"); *request.mutable_opts() = opts; std::cout << "Limiting the number of trace events to " << kMaxEvents diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc index 5e85a967ad4ea373e213fa90c3640e9ab1f92d25..98cc31f18d2d34765f2c123c3d34207802541036 100644 --- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/contrib/tpu/profiler/op_profile.pb.h" #include "tensorflow/contrib/tpu/profiler/trace_events.pb.h" #include "tensorflow/contrib/tpu/profiler/trace_events_to_json.h" -#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/compression.h" #include "tensorflow/core/lib/io/path.h" @@ -30,8 +29,6 @@ limitations under the License. 
 #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/protobuf/config.pb.h" -#include "tensorflow/core/util/event.pb.h" #include "tensorflow/core/util/events_writer.h" namespace tensorflow { @@ -41,6 +38,7 @@ namespace { using ::tensorflow::io::JoinPath; using ::tensorflow::protobuf::util::JsonOptions; using ::tensorflow::protobuf::util::MessageToJsonString; +using ::tensorflow::str_util::EndsWith; using ::tensorflow::strings::StrCat; constexpr char kGraphRunPrefix[] = "tpu_profiler.hlo_graph."; @@ -49,6 +47,9 @@ constexpr char kJsonTraceFileName[] = "trace.json.gz"; constexpr char kProfilePluginDirectory[] = "plugins/profile/"; constexpr char kProtoTraceFileName[] = "trace"; +constexpr char kFlatProfilerFileName[] = "flat_profiler.pb"; +constexpr char kTfStatsHelperSuffix[] = "tf_stats_helper_result"; + Status WriteGzippedDataToFile(const string& filename, const string& data) { std::unique_ptr<WritableFile> file; TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(filename, &file)); @@ -110,6 +111,10 @@ Status DumpToolDataToLogDirectory(StringPiece run_dir, const string& host_prefix, const tensorflow::ProfileToolData& tool, std::ostream* os) { + // Don't save the intermediate results for combining the per host tool data.
+ if (EndsWith(tool.name(), kFlatProfilerFileName) || + EndsWith(tool.name(), kTfStatsHelperSuffix)) + return Status::OK(); string path = JoinPath(run_dir, StrCat(host_prefix, tool.name())); TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), path, tool.data())); if (os) { diff --git a/tensorflow/contrib/tpu/profiler/op_profile.proto b/tensorflow/contrib/tpu/profiler/op_profile.proto index 840a43913ba0f159d3c495553ebdff79c0448e73..1f249de314a54067ffbe7193e3135912a091b10a 100644 --- a/tensorflow/contrib/tpu/profiler/op_profile.proto +++ b/tensorflow/contrib/tpu/profiler/op_profile.proto @@ -60,6 +60,11 @@ message Metrics { // - it does not reveal the peak core FLOPS of the hardware double flops = 2; + // The VMEM bandwidth used to load operands from HBM, as a fraction of + // thereotical VMEM bandwidth on the specific hardware. + double memory_bandwidth = 3; + double raw_time = 11; // Elapsed core-time in picoseconds. double raw_flops = 12; // Total floating-point operations performed. + double raw_bytes_accessed = 13; // Total bytes accessed (include read/write). } diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto index b9ac1a550c87e055fd5d555c346d5ec545bbe634..2b13343efa4e82386cb9259432b854be3ec821f7 100644 --- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto +++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto @@ -87,6 +87,8 @@ message StepInfoResult { optional uint64 wait_duration_ps = 5; // The time spent on cross-replica-sum in picoseconds. optional uint64 crs_duration_ps = 6; + // Percentage of unit b time spent on infeed. + optional double unit_b_infeed_percent = 7; } // Result proto for a sequence of steps. 
diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto index 7be694e866729c58efae4ccf7932dd929c03ed91..f0fca63db0bca80cdaa27e491b2a03ae2246c007 100644 --- a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto +++ b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto @@ -68,7 +68,8 @@ message ProfileRequest { } message ProfileToolData { - // The tool's name which this data is associated. (e.g. "input_pipeline".) + // The file name with which this data is associated (e.g. + // "input_pipeline.json", "cluster_xxx.memory_viewer.json"). string name = 1; // The data payload (likely json) for the specific tool. diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto index 8b0bbde98e6a1dee8ade789328f3ba0624049562..d3c34bfd490080b86cf3d8b893c550f3a87bbbed 100644 --- a/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto +++ b/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto @@ -38,6 +38,9 @@ message EnumProfileSessionsAndToolsResponse { message ProfileSessionDataRequest { string repository_root = 1; string session_id = 2; + // Which host the data is associated with. If empty, data from all hosts is + // aggregated. + string host_name = 5; // Which tool string tool_name = 3; // Tool's specific parameters. e.g. 
TraceViewer's viewport etc diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py index 2e472a2805f98b15505f56af403aa6223e28c667..d879170b6875b3088d284459b70dc91567e33bab 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets.py @@ -166,11 +166,21 @@ def StreamingFilesDataset(files, return remote_iterator.get_next() def MapFn(unused_input): - return functional_ops.remote_call( + if isinstance(source_dataset.output_types, dtypes.DType): + output_types = [source_dataset.output_types] + elif isinstance(source_dataset.output_types, (list, tuple)): + output_types = source_dataset.output_types + else: + raise ValueError('source dataset has invalid output types') + remote_calls = functional_ops.remote_call( args=[source_handle], - Tout=[dtypes.string], + Tout=output_types, f=LoadingFunc, - target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job)[0] + target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job) + if len(remote_calls) == 1: + return remote_calls[0] + else: + return remote_calls with ops.device('/job:%s' % worker_job): output_dataset = dataset_ops.Dataset.range(2).repeat().map( diff --git a/tensorflow/contrib/tpu/python/tpu/datasets_test.py b/tensorflow/contrib/tpu/python/tpu/datasets_test.py index 918cf0ed8e513de0d4207f7d2aac61ad886c8288..b58d05eac56f3586e183333f7c1a3867ee57456c 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets_test.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets_test.py @@ -26,6 +26,8 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.platform import test from tensorflow.python.training import server_lib @@ -162,6 
+164,30 @@ class DatasetsTest(test.TestCase): self.assertEqual(set(all_contents), set(retrieved_values)) + def testArbitraryReaderFuncFromDatasetGenerator(self): + + def my_generator(): + yield (1, [1] * 10) + + def gen_dataset(dummy): + return dataset_ops.Dataset.from_generator( + my_generator, (dtypes.int64, dtypes.int64), + (tensor_shape.TensorShape([]), tensor_shape.TensorShape([10]))) + + dataset = datasets.StreamingFilesDataset( + dataset_ops.Dataset.range(10), filetype=gen_dataset) + + iterator = dataset.make_initializable_iterator() + self._sess.run(iterator.initializer) + get_next = iterator.get_next() + + retrieved_values = self._sess.run(get_next) + + self.assertIsInstance(retrieved_values, (list, tuple)) + self.assertEqual(len(retrieved_values), 2) + self.assertEqual(retrieved_values[0], 1) + self.assertItemsEqual(retrieved_values[1], [1] * 10) + def testUnexpectedFiletypeString(self): with self.assertRaises(ValueError): datasets.StreamingFilesDataset( diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index b1d8d38a9a0e68719380d062a2357930b2f5f167..f1a11fa6548b87d6222a97c72b8db5442c8ef774 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -49,20 +49,23 @@ from __future__ import print_function import collections import re +import time from tensorflow.contrib.framework.python.framework import experimental +from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu +from tensorflow.contrib.tpu.python.tpu import tpu_optimizer from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session as tf_session from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import ops from tensorflow.python.framework import 
tensor_spec -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import layers -from tensorflow.python.keras._impl.keras import models -from tensorflow.python.keras._impl.keras import optimizers as keras_optimizers -from tensorflow.python.keras._impl.keras.layers import embeddings +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import layers +from tensorflow.python.keras import models +from tensorflow.python.keras import optimizers as keras_optimizers +from tensorflow.python.keras.layers import embeddings from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import tf_logging as logging @@ -75,9 +78,6 @@ class TPUEmbedding(embeddings.Embedding): replacement: it has the same behavior and will work on CPU and GPU devices. """ - def __init__(self, *args, **kw): - super(TPUEmbedding, self).__init__(*args, **kw) - def build(self, input_shape): if input_shape[0] is None: raise ValueError( @@ -92,10 +92,11 @@ class TPUEmbedding(embeddings.Embedding): return math_ops.tensordot(inputs, self.embeddings, 1) -class CompiledTPUOp( +class TPUModelOp( collections.namedtuple( - 'CompiledTPUOp', - ['tpu_execute_op', 'infeed_tensors', 'infeed_op', 'outfeed_op'])): + 'TPUModelOp', + ['compile_op', 'execute_op', 'infeed_tensors', 'infeed_op', + 'outfeed_op'])): pass @@ -104,6 +105,15 @@ def _valid_name(tensor_name): return re.sub('[^a-zA-Z0-9_-]+', '', tensor_name) +def _replicated_optimizer(opt, num_replicas): + """Wrap the optimizer `opt` with CrossShardOptimizer if applicable.""" + if num_replicas == 1: + return opt + return keras_optimizers.TFOptimizer( + optimizer=tpu_optimizer.CrossShardOptimizer(opt.optimizer) + ) + + class TPUFunction(object): """K.function compatible interface for invoking a TPU compiled function. @@ -116,10 +126,11 @@ class TPUFunction(object): instead of being injected as `feed_dict` items or fetches. 
""" - def __init__(self, model, execution_mode): + def __init__(self, model, execution_mode, num_replicas=1): self.model = model self.execution_mode = execution_mode self._compilation_cache = {} + self.num_replicas = num_replicas def _specialize_model(self, input_specs): """Specialize `self.model` (a Keras model) for the given input shapes.""" @@ -165,9 +176,11 @@ class TPUFunction(object): # Call our model with our infeed inputs (re-using the weights). model_outputs = self.model(tpu_inputs) child_model = models.Model(inputs=tpu_inputs, outputs=model_outputs) + if is_training or is_test: child_model.compile( - optimizer=self.model.optimizer, + optimizer=_replicated_optimizer(self.model.optimizer, + self.num_replicas), loss=self.model.loss, loss_weights=self.model.loss_weights, metrics=self.model.metrics, @@ -185,7 +198,8 @@ class TPUFunction(object): return [ child_model.train_function.updates_op, tpu_ops.outfeed_enqueue_tuple( - child_model.train_function.outputs, name='oufeed-enqueue-train') + child_model.train_function.outputs, + name='outfeed-enqueue-train') ] elif is_test: child_model._make_test_function() @@ -195,7 +209,8 @@ class TPUFunction(object): ] return [ tpu_ops.outfeed_enqueue_tuple( - child_model.test_function.outputs, name='outfeed-enqueue-test') + child_model.test_function.outputs, + name='outfeed-enqueue-test') ] elif is_predict: child_model._make_predict_function() @@ -215,29 +230,85 @@ class TPUFunction(object): # Capture outfeed metadata computed during the rewrite. self._outfeed_spec = None - tpu_execute_op = tpu.rewrite(_model_fn) + # Generate out TPU operations using `tpu.split_compile_and_replicate`. + # `compile_op` can be used to test the TPU model compiles before execution. + # `execute op` replicates `_model_fn` `num_replicas` times, with each shard + # running on a different logical core. 
+ compile_op, execute_op = tpu.split_compile_and_replicate( + _model_fn, inputs=[[]] * self.num_replicas) # Generate CPU side operations to enqueue features/labels and dequeue # outputs from the model call. - with ops.device('/device:TPU:0'): - infeed_tensors = [] - for spec in input_specs: - infeed_tensors.append( - array_ops.placeholder( - dtype=spec.dtype, - shape=spec.shape, - name='infeed-enqueue-%s' % spec.name)) - - infeed_op = tpu_ops.infeed_enqueue_tuple( - infeed_tensors, [spec.shape for spec in input_specs], - name='infeed-enqueue-%s' % self.execution_mode) - - outfeed_op = tpu_ops.outfeed_dequeue_tuple( - dtypes=[spec.dtype for spec in self._outfeed_spec], - shapes=[spec.shape for spec in self._outfeed_spec], - name='outfeed-dequeue-%s' % self.execution_mode) - - return CompiledTPUOp(tpu_execute_op, infeed_tensors, infeed_op, outfeed_op) + infeed_op = [] + outfeed_op = [] + shard_infeed_tensors = [] + + for shard_id in range(self.num_replicas): + with ops.device('/device:TPU:%d' % shard_id): + infeed_tensors = [] + for spec in input_specs: + infeed_tensors.append( + array_ops.placeholder( + dtype=spec.dtype, + shape=spec.shape, + name='infeed-enqueue-%s-%d' % (spec.name, shard_id))) + shard_infeed_tensors.append(infeed_tensors) + + infeed_op.append(tpu_ops.infeed_enqueue_tuple( + infeed_tensors, [spec.shape for spec in input_specs], + name='infeed-enqueue-%s-%d' % (self.execution_mode, shard_id))) + + outfeed_op.extend(tpu_ops.outfeed_dequeue_tuple( + dtypes=[spec.dtype for spec in self._outfeed_spec], + shapes=[spec.shape for spec in self._outfeed_spec], + name='outfeed-dequeue-%s-%d' % (self.execution_mode, shard_id))) + + return TPUModelOp( + compile_op, execute_op, infeed_tensors=shard_infeed_tensors, + infeed_op=infeed_op, outfeed_op=outfeed_op) + + def _test_model_compiles(self, tpu_model_ops): + """Verifies that the given TPUModelOp can be compiled via XLA.""" + session = K.get_session() + + logging.info('Started compiling') + start_time = 
time.clock() + + result = session.run(tpu_model_ops.compile_op) + proto = tpu_compilation_result.CompilationResultProto() + proto.ParseFromString(result) + if proto.status_error_message: + raise RuntimeError( + 'Compilation failed: {}'.format(proto.status_error_message)) + + end_time = time.clock() + logging.info('Finished compiling. Time elapsed: %s secs', + end_time - start_time) + + def _split_tensors(self, inputs): + """Split input data across shards. + + Each input is sliced along the batch axis. + + Args: + inputs: List of NumPy arrays to run on the TPU. + + Returns: + List of lists containing the input to feed to each TPU shard. + """ + if self.num_replicas == 1: + return [inputs] + + batch_size = inputs[0].shape[0] + assert batch_size % self.num_replicas == 0, ( + 'batch_size must be divisible by num_replicas') + shard_size = batch_size // self.num_replicas + input_list = [] + for index in range(self.num_replicas): + shard_inputs = [x[index * shard_size:(index + 1) * shard_size] + for x in inputs] + input_list.append(shard_inputs) + return input_list def __call__(self, inputs): assert isinstance(inputs, list) @@ -250,12 +321,18 @@ class TPUFunction(object): else: input_tensors = self.model._feed_inputs + shard_inputs = self._split_tensors(inputs) + del inputs # To avoid accidental use. + # Compute an input specification (used to generate infeed enqueue and # dequeue operations). We use the shape from our input array and the # dtype from our model. A user may pass in a float64 for a float32 # input: for model compatibility we still must generate a float32 infeed. input_specs = [] - for tensor, ary in zip(input_tensors, inputs): + + # We use the shape and dtype from the first shard to compute the input + # metadata (`input_specs`); all replicas have the same type and shape. 
+ for tensor, ary in zip(input_tensors, shard_inputs[0]): input_specs.append( tensor_spec.TensorSpec(ary.shape, tensor.dtype, _valid_name(tensor.name))) @@ -268,21 +345,26 @@ class TPUFunction(object): if shape_key not in self._compilation_cache: logging.info('New input shapes; (re-)compiling: mode=%s, %s', self.execution_mode, input_specs) - self._compilation_cache[shape_key] = self._specialize_model(input_specs) + new_tpu_model_ops = self._specialize_model(input_specs) + self._compilation_cache[shape_key] = new_tpu_model_ops + self._test_model_compiles(new_tpu_model_ops) - compiled_model = self._compilation_cache[shape_key] + tpu_model_ops = self._compilation_cache[shape_key] infeed_dict = {} - for tensor, value in zip(compiled_model.infeed_tensors, inputs): - infeed_dict[tensor] = value + for infeed_tensors, inputs in zip(tpu_model_ops.infeed_tensors, + shard_inputs): + for tensor, value in zip(infeed_tensors, inputs): + infeed_dict[tensor] = value session = K.get_session() _, _, outfeed_outputs = session.run([ - compiled_model.infeed_op, compiled_model.tpu_execute_op, - compiled_model.outfeed_op + tpu_model_ops.infeed_op, tpu_model_ops.execute_op, + tpu_model_ops.outfeed_op ], infeed_dict) - return outfeed_outputs + # TODO(xiejw): Decide how to reduce outputs, or just discard all but first. 
+ return outfeed_outputs[:len(outfeed_outputs) // self.num_replicas] @experimental @@ -317,8 +399,8 @@ def shutdown_tpu_session(session=None): class KerasTPUModel(models.Model): """TPU compatible Keras model wrapper.""" - def __init__(self, inputs, outputs, name=None): - super(models.Model, self).__init__( + def __init__(self, inputs, outputs, name, replicas=1): + super(models.Model, self).__init__( # pylint: disable=bad-super-call inputs=inputs, outputs=outputs, name=name, @@ -326,6 +408,7 @@ class KerasTPUModel(models.Model): self.predict_function = None self.test_function = None self.train_function = None + self.replicas = replicas def compile(self, optimizer, @@ -354,7 +437,8 @@ class KerasTPUModel(models.Model): def _make_train_function(self): if not self.train_function: - self.train_function = TPUFunction(self, model_fn_lib.ModeKeys.TRAIN) + self.train_function = TPUFunction(self, model_fn_lib.ModeKeys.TRAIN, + num_replicas=self.replicas) return self.train_function @@ -420,7 +504,53 @@ Output shape: %(output_shape)s @experimental -def tpu_model(model): +def tpu_model(model, replicas=None): + """Runs a model on TPU(s). + + Usage: + ``` + a = Input(shape=(32,)) + b = Dense(32)(a) + model = Model(inputs=a, outputs=b) + + model = keras_support.tpu_model(model) + model.compile( + optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0), + ...) + ``` + + If `replicas` is set, replicates the model computation on all TPU cores. The + model computation is replicated `num_replicas` times; each shard will run on a + different TPU core. + + Limitation: Currently, replication is only supported for training. + + Usage: + ``` + a = Input(shape=(32,)) + b = Dense(32)(a) + model = Model(inputs=a, outputs=b) + + model = keras_support.tpu_model(model, replicas=2) + model.compile( + optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0), + ...) + ``` + + Args: + model: A `KerasTPUModel`. 
+ replicas: (Optional) Int, number of TPU cores on which to create model + replicas. If `None`, the model runs on a single core only, i.e., no + replication. + + Returns: + A new `KerasTPUModel` instance. + """ _validate_shapes(model) + # TODO(xiejw): Validate TPU model. TPUModel only? + # TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset? + # TODO(xiejw): Add a reduction option. + replicas = 1 if replicas is None else replicas return KerasTPUModel( - inputs=model.inputs, outputs=model.outputs, name=model.name) + inputs=model.inputs, outputs=model.outputs, name=model.name, + replicas=replicas) diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py index faf677a81d0827c9286808797181abe9c4b82c63..3e91e2df32e6f18b7f74c1d81f64776e59d09c2a 100644 --- a/tensorflow/contrib/tpu/python/tpu/session_support.py +++ b/tensorflow/contrib/tpu/python/tpu/session_support.py @@ -292,14 +292,21 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook): if self._saver: return self._saver - savers = ops.get_collection(ops.GraphKeys.SAVERS)[0] + savers = ops.get_collection(ops.GraphKeys.SAVERS) if not savers: return None if not isinstance(savers, list): return savers - assert len(savers) == 1, 'Only one saver supported.' + if len(savers) > 1: + logging.error( + 'Multiple savers in the SAVERS collection. On-demand checkpointing ' + 'will be disabled. Pass an explicit `saver` to the constructor to ' + 'override this behavior.' 
+ ) + return None + return savers[0] def after_run(self, run_context, run_values): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index c8f24ed01d13a1325ed3d77d1d91d4df79b0e379..1c482950e64a9537a2996df66ed9403e53cf8a71 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -21,6 +21,7 @@ from __future__ import print_function from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu_function @@ -125,7 +126,19 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): outside the replicated computation. """ - def __init__(self, name, num_replicas): + def __init__(self, name, num_replicas, pivot): + """Builds a new TPUReplicateContext. + + Args: + name: a unique name for the context, used to populate the `_tpu_replicate` + attribute. + num_replicas: an integer that gives the number of replicas for the + computation. + pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any + inputs will have a control dependency on the pivot node. This ensures + that nodes are correctly included in any enclosing control flow + contexts. 
+ """ super(TPUReplicateContext, self).__init__() self._num_replicas = num_replicas self._outer_device_function_stack = None @@ -137,6 +150,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): self._host_compute_core = [] self._name = name self._unsupported_ops = [] + self._pivot = pivot def report_unsupported_operations(self): if self._unsupported_ops: @@ -261,9 +275,6 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): self._outer_device_function_stack = list(graph._device_function_stack) # pylint: disable=protected-access super(TPUReplicateContext, self).Enter() - def Exit(self): - super(TPUReplicateContext, self).Exit() - def HostComputeCore(self): return self._host_compute_core @@ -299,10 +310,64 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): op.graph.prevent_feeding(op) op.graph.prevent_fetching(op) + # Remove any control edges from outer control flow contexts. These may cause + # mismatched frame errors. + control_inputs, external_inputs = self._RemoveExternalControlEdges(op) + + if not op.inputs: + # Add a control edge from the control pivot to this op. + if not control_inputs: + # pylint: disable=protected-access + op._add_control_input(self.GetControlPivot()) + # pylint: enable=protected-access + else: + for index in xrange(len(op.inputs)): + x = op.inputs[index] + real_x = self.AddValue(x) + if real_x != x: + op._update_input(index, real_x) # pylint: disable=protected-access + + if external_inputs: + # Use an identity to pull control inputs as data inputs. Note that we + # ignore ops which don't have outputs. TODO(phawkins): fix that. + with ops.control_dependencies(None): + self.Enter() + external_inputs = [ + array_ops.identity(x.outputs[0]).op + for x in external_inputs + if x.outputs + ] + self.Exit() + # pylint: disable=protected-access + op._add_control_inputs(external_inputs) + # pylint: enable=protected-access + + # Mark op's outputs as seen by this context and any outer contexts. 
+ output_names = [x.name for x in op.outputs] + context = self + while context is not None: + # pylint: disable=protected-access + context._values.update(output_names) + context = context._outer_context + # pylint: enable=protected-access + + if self._outer_context: + self._outer_context.AddInnerOp(op) + def AddValue(self, val): + if val.name in self._values: + # Use the real value if it comes from outer context. + result = self._external_values.get(val.name) + return val if result is None else result + result = val + self._values.add(val.name) if self._outer_context: result = self._outer_context.AddValue(val) + self._values.add(result.name) + + self._external_values[val.name] = result + return result def AddInnerOp(self, op): @@ -318,17 +383,30 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): # grad_state should be as if this is the top-level gradient state. return None + @property + def back_prop(self): + """Forwards to the enclosing while context, if any.""" + if self.GetWhileContext(): + return self.GetWhileContext().back_prop + return False -def outside_compilation(computation, args=None): + def GetControlPivot(self): + return self._pivot + + +def outside_compilation(computation, *args, **kwargs): """Builds part of a computation outside any current TPU replicate scope. Args: computation: A Python function that builds the computation to place on the host. - args: Inputs to pass to computation. + *args: the positional arguments for the computation. + **kwargs: the keyword arguments for the computation. + Returns: The Tensors returned by computation. 
""" + args = [] if args is None else args graph = ops.get_default_graph() # If we are in a TPUReplicateContext, signal that we are now @@ -340,7 +418,7 @@ def outside_compilation(computation, args=None): context._EnterOutsideCompilationScope() # pylint: disable=protected-access context = context.outer_context - retval = computation(*args) + retval = computation(*args, **kwargs) # If we are in a TPUReplicateContext, signal that we are no longer # outside_compilation @@ -394,7 +472,8 @@ def split_compile_and_replicate(computation, inputs=None, infeed_queue=None, device_assignment=None, - name=None): + name=None, + use_tpu=True): """Builds graph operators that runs compilation and replicated computation. This is a lower level interface than replicate that returns a separate compile @@ -417,6 +496,9 @@ def split_compile_and_replicate(computation, only one core, and there is either only one replica, or the number of replicas is equal to the number of cores in the TPU system. name: (Deprecated) Does nothing. + use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU + backends. Currently, only supports a default placement (computation is + placed on GPU if one is available, and on CPU if not). Returns: A list of lists with the first list corresponding to the compile op and the second a list of output tensors, indexed by `[replica_num][output_num]`. 
@@ -497,12 +579,14 @@ def split_compile_and_replicate(computation, tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) cluster_name = graph.unique_name("cluster") - context = TPUReplicateContext(name=cluster_name, num_replicas=num_replicas) + pivot = control_flow_ops.no_op(name=cluster_name + "/pivot") + context = TPUReplicateContext( + name=cluster_name, num_replicas=num_replicas, pivot=pivot) try: context.Enter() metadata = tpu_ops.tpu_replicate_metadata( - num_replicas=num_replicas, **metadata_kwargs) + num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs) with tpu_function.tpu_shard_context( num_replicas), ops.control_dependencies([metadata]): @@ -539,6 +623,11 @@ def split_compile_and_replicate(computation, vscope.set_use_resource(saved_use_resource) + # If the computation returns `None`, add `no_op` here so that when user + # fetches `no_op` returned by this function, the TPUExecute node will be + # triggered. + if outputs is None: + outputs = (control_flow_ops.no_op(),) # If the computation only returned one value, makes it a tuple. 
if not isinstance(outputs, (list, tuple)): outputs = (outputs,) @@ -574,6 +663,7 @@ def split_compile_and_replicate(computation, with ops.device(t.device if t.device else core(0)): new_output_tensors.append(array_ops.identity(t)) output_tensors = new_output_tensors + context.ExitResult(output_tensors) finally: context.report_unsupported_operations() context.Exit() @@ -590,10 +680,13 @@ def split_compile_and_replicate(computation, for i in xrange(output_arity)] with ops.control_dependencies([metadata]): - compile_status = tpu_ops.tpu_compilation_result() - op = compile_status.op - attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name)) - op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access + if use_tpu: + compile_status = tpu_ops.tpu_compilation_result() + op = compile_status.op + attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name)) + op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access + else: + compile_status = control_flow_ops.no_op(name="compilation_status") with ops.control_dependencies(output_operations): if output_arity == 0: @@ -860,3 +953,152 @@ def rewrite(computation, device_assignment=device_assignment, name=name)[0] # pylint: enable=indexing-exception + + # Operations that indicate some error in the user's inference graph. +_BLACKLISTED_INFERENCE_OPS = set([ + "ReadVariableOp", + "AssignVariableOp", + "AssignAddVariableOp", + "AssignSubVariableOp", + "VarHandleOp", + "Variable", + "VariableV2", +]) + + +class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext): + """A `ControlFlowContext` for nodes inside a TPU inference computation. + + The primary role of `_TPUInferenceContext` is to sanity check operators + inside a tpu.rewrite_for_inference() computation. 
+ """ + + def __init__(self, name): + super(_TPUInferenceContext, self).__init__() + self._name = name + + def AddOp(self, op): + self._AddOpInternal(op) + + def _AddOpInternal(self, op): + # pylint: disable=protected-access + if op.type in _BLACKLISTED_INFERENCE_OPS: + raise NotImplementedError( + "Operation of type %s (%s) is not supported on the TPU for inference." + " Execution will fail if this op is used in the graph. Make sure your" + " variables are using variable_scope." % (op.type, op.name)) + if self._outer_context: + self._outer_context.AddInnerOp(op) + + def AddValue(self, val): + result = val + if self._outer_context: + result = self._outer_context.AddValue(val) + return result + + def AddInnerOp(self, op): + self._AddOpInternal(op) + + @property + def grad_state(self): + return None + + +@experimental +def validate_inference_rewrite_for_variables(graph): + """Validates whether rewrite_for_inference() 'worked' for variables. + + The rewrite_for_inference() method is supposed to append + GuaranteeConstOps after ReadVariableOps, but this mechanism works only + if you are using tf.get_variable() to create and access variables in your + tpu computation. This validation method can be called immediately after + calling tpu.rewrite_for_inference() to check whether GuaranteeConstOps + were added to the graph. + + Typical usages: + tpu.validate_inference_rewrite_for_variables(tf.get_default_graph()) + + tpu.validate_inference_rewrite_for_variables(sess.graph) + + Args: + graph: The graph which needs to be validated. + Raises: + RuntimeError: if validation failed. + """ + if not any([x.type == "GuaranteeConst" for x in graph.get_operations()]): + raise RuntimeError( + "No GuaranteeConst ops found in the graph after " + "running tpu.rewrite_for_inference(...). 
Please " + "check that you are using tf.get_variable() to " + "create and access variables in your tpu " + "computation.") + + +@experimental +def rewrite_for_inference(computation, + inputs=None, + infeed_queue=None, + device_assignment=None, + name=None): + """Rewrites `computation` for inference on a TPU system. + + Other than 'rewriting' the computation to run on a TPU, if using variables + in your computation, it moves the ReadVariableOps outside the TPU + computation, and adds GuaranteeConst ops just after the ReadVariableOps. + This mechanism works only if you are using tf.get_variable() to create and + access variables in your tpu computation. You can validate whether + this worked by calling the validate_inference_rewrite_for_variables() method + immediately after this method to check whether GuaranteeConstOps were + added to the graph. + + Args: + computation: A Python function that builds a computation to apply + to the input. If the function takes n inputs, 'inputs' should be + a list of n tensors. If the function returns m outputs, rewrite + will return a list of m tensors. + inputs: A list of input tensors or `None` (equivalent to an empty list). + infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple + of arguments as inputs to `computation`. + device_assignment: if not `None`, a `DeviceAssignment` describing the + mapping between logical cores in the computation with physical cores in + the TPU topology. May be omitted for a single-core computation, in which + case the core attached to task 0, TPU device 0 is used. + name: The name of the operator. + Returns: + A list of output tensors. 
+ """ + + def guarantee_const_getter(getter, name, *args, **kwargs): + with ops.control_dependencies(None): + return array_ops.guarantee_const( + getter(name, *args, **kwargs), name=name + "/GuaranteeConst") + + def wrapped_computation(*args, **kwargs): + """Execute computation under `_TPUInferenceContext`.""" + context = _TPUInferenceContext( + name=ops.get_default_graph().unique_name("rewrite_for_inference")) + try: + context.Enter() + + vscope = variable_scope.get_variable_scope() + prev_custom_getter = vscope.custom_getter + prev_caching_device = vscope.caching_device + vscope.set_custom_getter(guarantee_const_getter) + vscope.set_caching_device(lambda op: op.device) + + result = computation(*args, **kwargs) + + vscope.set_custom_getter(prev_custom_getter) + vscope.set_caching_device(prev_caching_device) + finally: + context.Exit() + return result + + # pylint: disable=undefined-variable + return rewrite( + wrapped_computation, + inputs=inputs, + infeed_queue=infeed_queue, + device_assignment=device_assignment, + name=name) + # pylint: enable=undefined-variable diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index 4d7bc6a5a65eaae9810b7f01bdb96b3537ba9896..5b9aeaa8797b92b4cc596744812f440607054dce 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -35,7 +35,98 @@ _DEFAULT_COORDINATOR_JOB_NAME = 'coordinator' _LOCAL_MASTERS = ('', 'local') -class _TPUContext(object): +class TPUContext(object): + """The context of current input_fn invocation.""" + + def __init__(self, internal_ctx, input_device=None, invocation_index=None): + self._internal_ctx = internal_ctx + self._input_device = input_device + self._invocation_index = invocation_index + + def current_input_fn_deployment(self): + """The configuration of the current input_fn invocation. + + The configuration depends on `TPUConfig.per_host_input_for_training`. 
See + `TPUConfig` for details. + + Only set in params dict of input_fn + + Returns: + A tuple of + 1. Device spec string: String, is the current CPU host where the + input_fn is invoked. + 2. Current invocation index: Int, 0-based index of the input_fn + invocation. See next item for details. + 3. Total invocation count: Int, the total number of times to invoke the + input_fn on all CPU hosts. Each invocation will be passed with a new + `TPUContext` instance with current invocation index set properly. + 4. Total number of replicas consumed by current_invocation: Int, the + number of replicas fed by the data returned by current input_fn. For + example, for per_core input pipeline deployment + and non-model-parallelism, total invocation count is equal to + the number of cores in the system and num replicas consumed by + current invocation is 1. For per-host v2 input pipeline deployment, + total invocation count is equal to the number of hosts in the system + and num replicas consumed by current invocation is equal to number of + cores per host. + """ + if self._internal_ctx.is_input_sharded_per_core(): + total_invocation_count = (self._internal_ctx.num_hosts + * self._internal_ctx.num_of_replicas_per_host) + replicas_consumed = 1 + else: + total_invocation_count = self._internal_ctx.num_hosts + replicas_consumed = self._internal_ctx.num_of_replicas_per_host + return (self._input_device, self._invocation_index, + total_invocation_count, replicas_consumed) + + @property + def num_replicas(self): + """The total number of replicas. + + For non-model-parallelism, num_replicas should be the total num of TPU + cores in the system. + + Returns: + The number of replicas. + """ + return self._internal_ctx.num_replicas + + def device_for_replica(self, replica_id): + """Returns the tuple of (CPU device and device ordinal) for replica. + + This should be used for full replicate for non-model-parallelism. + + Args: + replica_id: Int, the replica index. 
+ + Returns: + A tuple of device spec for CPU device and int device ordinal. + """ + # Note that: For the non-model parallelism, the mapping could be + # a random permutation. The order should not matter in most cases + # as long as the model is replicated to all cores in the system. + + # If the precise replica_id to device mapping is required, please + # set the computation_shape as [1,1,1] in TPUConfig to enable + # the model parallelism. + if self._internal_ctx.model_parallelism_enabled: + raise RuntimeError( + 'device_for_replica is not yet implemented for model parallelism. ' + 'b/79689078.') + + master = self._internal_ctx.master_job + job_device = '' if master is None else ('/job:%s' % master) + + num_of_replicas_per_host = self._internal_ctx.num_of_replicas_per_host + host_id = replica_id // num_of_replicas_per_host + ordinal_id = replica_id % num_of_replicas_per_host + + host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id) + return (host_device, ordinal_id) + + +class _InternalTPUContext(object): """A context holds immutable states of TPU computation. This immutable object holds TPUEstimator config, train/eval batch size, and @@ -44,9 +135,13 @@ class _TPUContext(object): information commonly required by TPU computation, such as TPU device names, TPU hosts, shard batch size, etc. + if eval_on_tpu is False, then execution of eval on TPU is disabled. + if eval_on_tpu is True, but use_tpu is False, a warning is issued, + and TPU execution is disabled for all modes. + N.B.
As `mode` is not immutable state in Estimator, but essential to distinguish between TPU training and evaluation, a common usage for - _TPUContext with `mode` is as follows: + _InternalTPUContext with `mode` is as follows: ``` with _ctx.with_mode(mode) as ctx: if ctx.is_running_on_cpu(): @@ -55,12 +150,17 @@ class _TPUContext(object): """ def __init__(self, config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu): + predict_batch_size, use_tpu, eval_on_tpu=True): self._config = config self._train_batch_size = train_batch_size self._eval_batch_size = eval_batch_size self._predict_batch_size = predict_batch_size self._use_tpu = use_tpu + logging.info('_TPUContext: eval_on_tpu %s', eval_on_tpu) + if not use_tpu and eval_on_tpu: + logging.warning('eval_on_tpu ignored because use_tpu is False.') + + self._eval_on_tpu = eval_on_tpu self._model_parallelism_enabled = ( use_tpu and config.tpu_config.computation_shape) self._mode = None @@ -246,6 +346,10 @@ class _TPUContext(object): if not self._use_tpu: return True + if mode == model_fn_lib.ModeKeys.EVAL and not self._eval_on_tpu: + logging.info('_is_running_on_cpu: eval_on_tpu disabled') + return True + if mode != model_fn_lib.ModeKeys.PREDICT: return False @@ -345,6 +449,7 @@ class _TPUContext(object): @property def tpu_host_placement_function(self): """Returns the TPU host place function.""" + master = self.master_job def _placement_function(_sentinal=None, core_id=None, host_id=None): # pylint: disable=invalid-name @@ -473,8 +578,8 @@ class _TPUContext(object): self._lazy_validation_dict[mode] = True -class _OneCoreTPUContext(_TPUContext): - """Special _TPUContext for one core usage.""" +class _OneCoreTPUContext(_InternalTPUContext): + """Special _InternalTPUContext for one core usage.""" def __init__(self, config, train_batch_size, eval_batch_size, predict_batch_size, use_tpu): @@ -503,8 +608,8 @@ class _OneCoreTPUContext(_TPUContext): def _get_tpu_context(config, train_batch_size, eval_batch_size, - 
predict_batch_size, use_tpu): - """Returns an instance of `_TPUContext`.""" + predict_batch_size, use_tpu, eval_on_tpu): + """Returns an instance of `_InternalTPUContext`.""" if (config.tpu_config.num_shards == 1 and config.tpu_config.computation_shape is None): @@ -514,5 +619,5 @@ def _get_tpu_context(config, train_batch_size, eval_batch_size, return _OneCoreTPUContext(config, train_batch_size, eval_batch_size, predict_batch_size, use_tpu) - return _TPUContext(config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu) + return _InternalTPUContext(config, train_batch_size, eval_batch_size, + predict_batch_size, use_tpu, eval_on_tpu) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index a624eceed9a65c8e9dc7baf056383eba06c5a414..64ae35dfc5e6d385a23c2dba15562d71aae4d497 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -46,7 +46,8 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator import util +from tensorflow.python.estimator import util as estimator_util +from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -62,26 +63,31 @@ from tensorflow.python.ops import summary_ops_v2 as contrib_summary from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary from tensorflow.python.training import basic_session_run_hooks from 
tensorflow.python.training import evaluation from tensorflow.python.training import session_run_hook from tensorflow.python.training import training from tensorflow.python.training import training_util +from tensorflow.python.util import function_utils from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect + _INITIAL_LOSS = 1e7 _ZERO_LOSS = 0. _TPU_ESTIMATOR = 'tpu_estimator' _ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop' _BATCH_SIZE_KEY = 'batch_size' +_CTX_KEY = 'context' _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' _ONE_GIGABYTE = 1024 * 1024 * 1024 _TPU_ENQUEUE_OPS = '_tpu_enqueue_ops' _TPU_TRAIN_OP = '_tpu_train_op' +_REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference' -_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY] +_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY] # TODO(b/65703635): Flip the value and remove all dead code. Currently, this is @@ -116,6 +122,33 @@ def _create_global_step(graph): def _create_or_get_iterations_per_loop(): + """Creates or gets the iterations_per_loop variable. + + In TPUEstimator, the user-provided computation, the model_fn, is wrapped + inside a tf.while_loop for peak performance. The iterations of the loop are + specified by this variable, which adjusts its value on the CPU after each TPU + program execution and before the next TPU execution. + + The purpose of using a variable, rather than a constant, is to allow + TPUEstimator to adapt the TPU training iterations according to the final steps + specified by users. For example, if the user sets the iterations_per_loop as 4 + in TPUConfig and steps as 10 in TPUEstimator.train(), the iterations_per_loop + variable will have the following values before each TPU training.
+ + - 1st TPU execution: iterations_per_loop = 4 + - 2nd TPU execution: iterations_per_loop = 4 + - 3rd TPU execution: iterations_per_loop = 2 + + As model_fn increases the global step once per train_op invocation, the global + step is 10 after all TPU executions, matching the steps=10 inputs passed in by + users. + + Returns: + A TF non-trainable resource variable. + + Raises: + RuntimeError: If multiple iterations_per_loop variables were found. + """ graph = ops.get_default_graph() collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR) iter_vars = graph.get_collection(collection_name) @@ -382,20 +415,21 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): return def _cancel_session(): - # Close the session to avoid the main thread from hanging. If input - # pipeline triggers any error, the infeed thread dies but the main thread - # for TPU computation waits for the infeed enqueue forever. Close the - # Session to cancel the main thread Session.run execution. - # - # We sleep for a few seconds before closing to give some time - # for the TPU compilation error, if any, propagating, from TPU to CPU - # host. Compilation errors should be reported by the main thread so that - # the program can be interrupted and users can take action. Due to a race - # condition, the infeed thread might see an error first. Closing the - # session here immediately would result in a session cancellation - # exception in the main thread, instead of the expected compile error. - # User code that depends on having the proper exception type will - # therefore be confused. + """Closes the session to prevent the main thread from hanging. + + If the input pipeline triggers any error, the infeed thread dies but the main + thread for TPU computation waits for the infeed enqueue forever. Close the + Session to cancel the main thread Session.run execution.
+ + We sleep for a few seconds before closing to give some time for the TPU + compilation error, if any, propagating, from TPU to CPU host. Compilation + errors should be reported by the main thread so that the program can be + interrupted and users can take action. Due to a race condition, the + infeed thread might see an error first. Closing the session here + immediately would result in a session cancellation exception in the main + thread, instead of the expected compile error. User code that depends on + having the proper exception type will therefore be confused. + """ time.sleep(5) # If the main session is still running, the infeed/outfeed errors are @@ -626,8 +660,8 @@ class _StoppingPredictHook(session_run_hook.SessionRunHook): raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.') -def generate_per_core_enqueue_ops_fn_for_host(ctx, input_fn, - inputs_structure_recorder): +def generate_per_core_enqueue_ops_fn_for_host( + ctx, input_fn, inputs_structure_recorder, host_device, host_id): """Generates infeed enqueue ops for per-core input_fn on a single host.""" captured_infeed_queue = _CapturedObject() @@ -637,7 +671,12 @@ def generate_per_core_enqueue_ops_fn_for_host(ctx, input_fn, per_host_sharded_inputs = [] for core_ordinal in range(num_cores_per_host): with ops.name_scope('ordinal_%d' % (core_ordinal)): - inputs = _Inputs.from_input_fn(input_fn()) + user_context = tpu_context.TPUContext( + internal_ctx=ctx, + input_device=host_device, + invocation_index=host_id * ctx.num_of_cores_per_host + core_ordinal + ) + inputs = _Inputs.from_input_fn(input_fn(user_context)) if inputs.is_dataset: raise TypeError( '`input_fn` returning `Dataset` is not yet supported in ' @@ -674,7 +713,11 @@ def generate_per_host_enqueue_ops_fn_for_host( hooks = [] with ops.device(device): - inputs = _Inputs.from_input_fn(input_fn()) + user_context = tpu_context.TPUContext( + internal_ctx=ctx, + input_device=device, + invocation_index=host_id) + inputs = 
_Inputs.from_input_fn(input_fn(user_context)) is_dataset = inputs.is_dataset if ctx.mode == model_fn_lib.ModeKeys.PREDICT: @@ -692,7 +735,7 @@ def generate_per_host_enqueue_ops_fn_for_host( hooks.append(inputs.dataset_initializer_hook()) # TODO(ylc): Refactoring the code to merge the tpu ordinal logic here and the - # _TPUContext.tpu_ordinal_function. We should either introduce another + # _InternalTPUContext.tpu_ordinal_function. We should either introduce another # abstraction or a different helper method. def _tpu_ordinal_function_impl(shard_index_in_host): # We put both enqueue/dequeue op at tpu.core(0) in each replica. @@ -706,6 +749,15 @@ def generate_per_host_enqueue_ops_fn_for_host( tpu_ordinal_function = None def enqueue_ops_fn(): + """A Fn returning the TPU infeed enqueue ops. + + By providing as a Fn, it can be invoked inside the tf.while_loop such that + the input pipeline for multiple iterations can be executed by one + Session.run call. + + Returns: + list of dict of ops. + """ with ops.device(device): num_of_replicas_per_host = ctx.num_of_replicas_per_host # Convert user input to features and labels. If the user returns a @@ -745,12 +797,15 @@ def generate_per_host_enqueue_ops_fn_for_host( def generate_per_host_v2_enqueue_ops_fn_for_host( ctx, input_fn, inputs_structure_recorder, device, host_id): """Generates infeed enqueue ops for per-host input_fn on a single host.""" - del host_id # unused captured_infeed_queue = _CapturedObject() hooks = [] with ops.device(device): - inputs = _Inputs.from_input_fn(input_fn()) + user_context = tpu_context.TPUContext( + internal_ctx=ctx, + input_device=device, + invocation_index=host_id) + inputs = _Inputs.from_input_fn(input_fn(user_context)) is_dataset = inputs.is_dataset if not is_dataset: @@ -801,13 +856,14 @@ class _InputPipeline(object): """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue. `_InputPipeline` abstracts the per-core/per-host `input_fn` invocation from - call site. 
To be precise, based on the configuration in `_TPUContext`, it - invokes `input_fn` for all cores (usually multi-host TPU training) or for one - host (usually for single-host TPU evaluation), and sends all `features` and - `labels` returned by `input_fn` to TPU infeed. For per-core invocation, - `features` and `labels` are piped to infeed directly, one tuple for each - core. For per-host invocation, `features` and `labels` are split at host - (with respect to `batch_axis`) and piped to all cores accordingly. + call site. To be precise, based on the configuration in + `_InternalTPUContext`, it invokes `input_fn` for all cores (usually + multi-host TPU training) or for one host (usually for single-host TPU + evaluation), and sends all `features` and `labels` returned by `input_fn` to + TPU infeed. For per-core invocation, `features` and `labels` are piped to + infeed directly, one tuple for each core. For per-host invocation, `features` + and `labels` are split at host (with respect to `batch_axis`) and piped to all + cores accordingly. In addition, flatten/unflatten are handled by `_InputPipeline` also. Model inputs returned by the `input_fn` can have one of the following forms: @@ -960,7 +1016,7 @@ class _InputPipeline(object): batch_axis: A python tuple of int values describing how each tensor produced by the Estimator `input_fn` should be split across the TPU compute shards. - ctx: A `_TPUContext` instance with mode. + ctx: A `_InternalTPUContext` instance with mode. Raises: ValueError: If both `sharded_features` and `num_cores` are `None`. 
@@ -1015,7 +1071,8 @@ class _InputPipeline(object): with ops.name_scope('input_pipeline_task%d' % (host_id)): enqueue_ops_fn, captured_infeed_queue = ( generate_per_core_enqueue_ops_fn_for_host( - self._ctx, self._input_fn, self._inputs_structure_recorder)) + self._ctx, self._input_fn, self._inputs_structure_recorder, + host_device, host_id)) if _WRAP_INPUT_FN_INTO_WHILE_LOOP: run_infeed_loop_on_coordinator = False @@ -1075,10 +1132,16 @@ class _InputPipeline(object): return enqueue_ops, all_hooks, run_infeed_loop_on_coordinator def _validate_input_pipeline(self): - # Perform some sanity checks to log user friendly information. We should - # error out to give users better error message. But, if - # _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break - # user code, so, log a warning. + """Validates the input pipeline. + + Perform some sanity checks to log user friendly information. We should + error out to give users better error message. But, if + _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break + user code, so, log a warning. + + Raises: + RuntimeError: If the validation failed. + """ if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS): err_msg = ('Input pipeline contains one or more QueueRunners. ' 'It could be slow and not scalable. Please consider ' @@ -1249,13 +1312,11 @@ class _ModelFnWrapper(object): 'estimator_spec used by TPU prediction must have type' '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec))) + self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions) + captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) to_record = {} identity_fn = lambda **kwargs: kwargs - # TODO(xiejw): Adds validation for prediction dictionrary. - # TODO(xiejw): Adds support for single tensor as predictions. 
- if not isinstance(tpu_estimator_spec.predictions, dict): - raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.') to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions] to_record['signals'] = [identity_fn, stopping_signals] if tpu_estimator_spec.host_call is not None: @@ -1267,9 +1328,24 @@ class _ModelFnWrapper(object): return predict_step, host_calls, captured_scaffold_fn + def _verify_tpu_spec_predictions(self, predictions): + """Validates TPUEstimatorSpec.predictions dict.""" + # TODO(xiejw): Adds validation for prediction dictionary. + # TODO(xiejw): Adds support for single tensor as predictions. + if not isinstance(predictions, dict): + raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.') + + for (key, tensor) in predictions.items(): + if tensor.shape[0].value is None: + raise ValueError( + 'The tensor with key ({}) in TPUEstimatorSpec.predictions has ' + 'dynamic shape (should be static). Tensor: {}'.format( + key, tensor)) + return predictions + def _call_model_fn(self, features, labels, is_export_mode=False): """Calls the model_fn with required parameters.""" - model_fn_args = util.fn_args(self._model_fn) + model_fn_args = function_utils.fn_args(self._model_fn) kwargs = {} # Makes deep copy with `config` and params` in case user mutates them.
@@ -1299,10 +1375,7 @@ class _ModelFnWrapper(object): batch_size_for_model_fn = self._ctx.batch_size_for_model_fn if batch_size_for_model_fn is not None: - if isinstance(params, hparam.HParams): - params.add_hparam(_BATCH_SIZE_KEY, batch_size_for_model_fn) - else: - params[_BATCH_SIZE_KEY] = batch_size_for_model_fn + _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn) estimator_spec = self._model_fn(features=features, **kwargs) if (self._ctx.is_running_on_cpu(is_export_mode) and @@ -1361,7 +1434,7 @@ class _OutfeedHostCall(object): if isinstance(host_call[1], (tuple, list)): fullargspec = tf_inspect.getfullargspec(host_call[0]) - fn_args = util.fn_args(host_call[0]) + fn_args = function_utils.fn_args(host_call[0]) # wrapped_hostcall_with_global_step uses varargs, so we allow that. if fullargspec.varargs is None and len(host_call[1]) != len(fn_args): raise RuntimeError( @@ -1612,7 +1685,9 @@ class TPUEstimator(estimator_lib.Estimator): ========== `model_fn` should return `TPUEstimatorSpec`, which expects the `eval_metrics` - for TPU evaluation. + for TPU evaluation. However, if eval_on_tpu is False, `model_fn` must return + `EstimatorSpec` and the evaluation will execute on CPU or GPU; in this case + the following discussion on TPU evaluation does not apply. `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. (See @@ -1746,8 +1821,45 @@ class TPUEstimator(estimator_lib.Estimator): Exporting ========= - Exporting `SavedModel` support on TPU is not yet implemented. So, - `export_savedmodel` is executed on CPU, even if `use_tpu` is true. + `export_savedmodel` exports 2 metagraphs, one with `tag_constants.SERVING`, + and another with `tag_constants.SERVING` and `tag_constants.TPU`. + At serving time, these tags are used to select metagraph to load. + + Before running the graph on TPU, TPU system needs to be initialized. 
If + TensorFlow Serving model-server is used, this is done automatically. If + not, please call `session.run(tpu.initialize_system())`. + + `tpu.outside_compilation` can be used to wrap TPU incompatible ops in + `model_fn`. + + Example: + ---------------- + + ``` + def model_fn(features, labels, mode, config, params): + ... + logits = ... + export_outputs = { + 'logits': export_output_lib.PredictOutput( + {'logits': logits}) + } + + def host_call(logits): + class_ids = math_ops.argmax(logits) + classes = string_ops.as_string(class_ids) + export_outputs['classes'] = + export_output_lib.ClassificationOutput(classes=classes) + + tpu.outside_compilation(host_call, logits) + + ... + ``` + + Current limitations: + -------------------- + + 1. Outside compilation does not work yet (b/79991729). + """ def __init__(self, @@ -1759,13 +1871,17 @@ class TPUEstimator(estimator_lib.Estimator): train_batch_size=None, eval_batch_size=None, predict_batch_size=None, - batch_axis=None): + batch_axis=None, + eval_on_tpu=True, + export_to_tpu=True, + warm_start_from=None): """Constructs an `TPUEstimator` instance. Args: model_fn: Model function as required by `Estimator`. For training, the returned `EstimatorSpec` cannot have hooks as it is not supported in - `TPUEstimator`. + `TPUEstimator`. Instead, the user can pass the training hooks as + an argument to `TPUEstimator.train()`. model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. If `None`, the model_dir in @@ -1777,7 +1893,8 @@ class TPUEstimator(estimator_lib.Estimator): basic python types. There are reserved keys for `TPUEstimator`, including 'batch_size'. use_tpu: A bool indicating whether TPU support is enabled. Currently, - - TPU training and evaluation respect this bit. + - TPU training and evaluation respect this bit, but eval_on_tpu can + override execution of eval. See below. 
- Predict still happens on CPU. train_batch_size: An int representing the global training batch size. TPUEstimator transforms this global batch size to a per-shard batch @@ -1798,6 +1915,16 @@ class TPUEstimator(estimator_lib.Estimator): and per_host_input_for_training is True, batches will be sharded based on the major dimension. If tpu_config.per_host_input_for_training is False or `PER_HOST_V2`, batch_axis is ignored. + eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the + model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`. + export_to_tpu: If True, `export_savedmodel()` exports a metagraph for + serving on TPU besides the one on CPU. + warm_start_from: Optional string filepath to a checkpoint or SavedModel to + warm-start from, or a `tf.estimator.WarmStartSettings` + object to fully configure warm-starting. If the string + filepath is provided instead of a `WarmStartSettings`, + then all variables are warm-started, and it is assumed + that vocabularies and Tensor names are unchanged. Raises: ValueError: `params` has reserved keys already. @@ -1812,7 +1939,7 @@ class TPUEstimator(estimator_lib.Estimator): if use_tpu: # Perform some very basic validations. More validations will be found in - # _TPUContext. + # _InternalTPUContext. if train_batch_size is None: raise ValueError('`train_batch_size` cannot be `None`') util_lib.check_positive_integer(train_batch_size, 'train_batch_size') @@ -1850,19 +1977,122 @@ class TPUEstimator(estimator_lib.Estimator): model_fn=model_function, model_dir=model_dir, config=config, - params=params) + params=params, + warm_start_from=warm_start_from) self._iterations_per_training_loop = ( self._config.tpu_config.iterations_per_loop) - # All properties passed to _TPUContext are immutable. + # All properties passed to _InternalTPUContext are immutable. 
# pylint: disable=protected-access self._ctx = tpu_context._get_tpu_context( self._config, train_batch_size, eval_batch_size, predict_batch_size, - use_tpu) + use_tpu, + eval_on_tpu) + + self._export_to_tpu = export_to_tpu self._is_input_fn_invoked = None + def _add_meta_graph_for_mode(self, + builder, + input_receiver_fn_map, + checkpoint_path, + strip_default_attrs, + save_variables=True, + mode=model_fn_lib.ModeKeys.PREDICT, + export_tags=None): + if mode != model_fn_lib.ModeKeys.PREDICT: + raise NotImplementedError( + 'TPUEstimator only handles mode PREDICT for export_savedmodel(); ' + 'got {}.'.format(mode)) + + super(TPUEstimator, self)._add_meta_graph_for_mode(builder, + input_receiver_fn_map, + checkpoint_path, + strip_default_attrs, + save_variables, + mode=mode) + + if self._export_to_tpu: + input_receiver_fn_map = {_REWRITE_FOR_INFERENCE_MODE: + input_receiver_fn_map[mode]} + export_tags = [tag_constants.SERVING, tag_constants.TPU] + mode = _REWRITE_FOR_INFERENCE_MODE + (super(TPUEstimator, self). + _add_meta_graph_for_mode(builder, + input_receiver_fn_map, + checkpoint_path, + strip_default_attrs, + save_variables=False, + mode=mode, + export_tags=export_tags)) + + def _call_model_fn(self, features, labels, mode, config): + if mode == _REWRITE_FOR_INFERENCE_MODE: + return self._call_model_fn_for_inference(features, labels, mode, config) + else: + return super(TPUEstimator, self)._call_model_fn( + features, labels, mode, config) + + def _call_model_fn_for_inference(self, features, labels, mode, config): + """Wraps `_call_model_fn` for `export_savedmodel`.""" + if mode != _REWRITE_FOR_INFERENCE_MODE: + raise ValueError('mode must be {}; ' + 'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode)) + + capture = _CapturedObject() + + def computation(): + """Compute tpu tensors used in export_outputs. + + Passed to rewrite_for_inference so that model_fn will be called under + the rewriting contexts. 
Only tpu tensors are returned, but export_outputs + and scaffold are captured. + + Returns: + A list of Tensors used in export_outputs and not marked for + outside_compilation. + """ + # We should only call model fn once and it should be inside `computation` + # so that building the graph will happen under `rewrite_for_inference`. + mode = model_fn_lib.ModeKeys.PREDICT + estimator_spec = self._call_model_fn(features, labels, mode, config) + + # We pick the TPU tensors out from `export_output` and later return them + # from `computation` for rewriting. + tensors_dict = collections.OrderedDict( + (k, _export_output_to_tensors(v)) + for k, v in six.iteritems(estimator_spec.export_outputs) + ) + tensors = nest.flatten(tensors_dict) + tpu_tensors = [t for t in tensors if _is_tpu_tensor(t)] + + # We cannot return anything other than `tpu_tensors` here so we capture + # the rest for later use. + capture.capture((estimator_spec, tensors_dict, tensors)) + return tpu_tensors + + tpu_tensors_on_cpu = tpu.rewrite_for_inference(computation) + estimator_spec, tensors_dict, tensors = capture.get() + + # Reconstruct `tensors`, but with `tpu_tensors` replaced with + # `tpu_tensors_on_cpu`. + new_tensors = [ + tpu_tensors_on_cpu.pop(0) if _is_tpu_tensor(t) else t + for t in tensors + ] + # Reconstruct `tensors_dict`. + new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors) + # Reconstruct `export_outputs`. + export_outputs = estimator_spec.export_outputs + new_export_outputs = collections.OrderedDict( + (k, _clone_export_output_with_tensors(export_outputs[k], v)) + for k, v in six.iteritems(new_tensors_dict) + ) + + return estimator_spec._replace(export_outputs=new_export_outputs) + def _create_global_step(self, graph): """Creates a global step suitable for TPUs. @@ -1930,7 +2160,7 @@ class TPUEstimator(estimator_lib.Estimator): Raises: ValueError: if input_fn takes invalid arguments or does not have `params`. 
""" - input_fn_args = util.fn_args(input_fn) + input_fn_args = function_utils.fn_args(input_fn) config = self.config # a deep copy. kwargs = {} if 'params' in input_fn_args: @@ -1953,10 +2183,8 @@ class TPUEstimator(estimator_lib.Estimator): # input_fn for use_tpu=True/False. batch_size_for_input_fn = ctx.batch_size_for_input_fn if batch_size_for_input_fn is not None: - if isinstance(kwargs['params'], hparam.HParams): - kwargs['params'].add_hparam(_BATCH_SIZE_KEY, batch_size_for_input_fn) - else: - kwargs['params'][_BATCH_SIZE_KEY] = batch_size_for_input_fn + _add_item_to_params(kwargs['params'], + _BATCH_SIZE_KEY, batch_size_for_input_fn) # For export_savedmodel, input_fn is never passed to Estimator. So, # `is_export_mode` must be False. @@ -1974,7 +2202,8 @@ class TPUEstimator(estimator_lib.Estimator): # tf.while_loop also. So, we either pass input_fn to model_fn or pass # dequeue_fn to model_fn. Here, `input_fn` is passed directly as # `features` in `model_fn` signature. - def _input_fn(): + def _input_fn(ctx): + _add_item_to_params(kwargs['params'], _CTX_KEY, ctx) return input_fn(**kwargs) return _input_fn @@ -2046,11 +2275,11 @@ class TPUEstimator(estimator_lib.Estimator): if shutdown_mode: if shutdown_mode == 'shutdown_worker': finalizer_hooks = [ - session_support.ShutdownLameWorkers(timeout_ms=1000), + session_support.ShutdownLameWorkers(timeout_ms=60*1000), ] elif shutdown_mode == 'shutdown_computation': finalizer_hooks = [ - session_support.RestartComputation(timeout_ms=1000), + session_support.RestartComputation(timeout_ms=60*1000), ] else: raise ValueError('Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE "%s"' % @@ -2239,6 +2468,76 @@ class TPUEstimator(estimator_lib.Estimator): return _model_fn +def _is_tpu_tensor(tensor): + if not isinstance(tensor, ops.Tensor): + return False + try: + tensor.op.get_attr(tpu._OUTSIDE_COMPILATION_ATTR) # pylint: disable=protected-access + except ValueError: + return True + else: + return False + + +def 
_export_output_to_tensors(export_output): + """Gets a list of `Tensors` used in `export_output`. + + Args: + export_output: an `ExportOutput` object such as `ClassificationOutput`, + `RegressionOutput`, or `PredictOutput`. + Returns: + a list of tensors used in export_output. + + Raises: + ValueError: if `export_output` is not one of `ClassificationOutput`, + `RegressionOutput`, or `PredictOutput`. + """ + if isinstance(export_output, export_output_lib.ClassificationOutput): + return [export_output.scores, export_output.classes] + elif isinstance(export_output, export_output_lib.RegressionOutput): + return [export_output.value] + elif isinstance(export_output, export_output_lib.PredictOutput): + return list(export_output.outputs.values()) + else: + raise ValueError( + '`export_output` must have type `ClassificationOutput`, ' + '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output)) + + +def _clone_export_output_with_tensors(export_output, tensors): + """Clones `export_output` but with new `tensors`. + + Args: + export_output: an `ExportOutput` object such as `ClassificationOutput`, + `RegressionOutput`, or `PredictOutput`. + tensors: a list of `Tensors` used to construct a new `export_output`. + + Returns: + A new `ExportOutput` of the same type as `export_output`, built from `tensors`. + + Raises: + ValueError: if `export_output` is not one of `ClassificationOutput`, + `RegressionOutput`, or `PredictOutput`.
+ """ + if isinstance(export_output, export_output_lib.ClassificationOutput): + if len(tensors) != 2: + raise ValueError('tensors must be of length 2; ' + 'got {}.'.format(len(tensors))) + return export_output_lib.ClassificationOutput(*tensors) + elif isinstance(export_output, export_output_lib.RegressionOutput): + if len(tensors) != 1: + raise ValueError('tensors must be of length 1; ' + 'got {}'.format(len(tensors))) + return export_output_lib.RegressionOutput(*tensors) + elif isinstance(export_output, export_output_lib.PredictOutput): + return export_output_lib.PredictOutput( + dict(zip(export_output.outputs.keys(), tensors))) + else: + raise ValueError( + '`export_output` must be have type `ClassificationOutput`, ' + '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output)) + + def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" iterations_per_loop_var = _create_or_get_iterations_per_loop() @@ -2386,7 +2685,7 @@ class _CapturedObject(object): def capture(self, o): if self._captured: raise RuntimeError( - 'InternalError: Object can be captured only. Please file bug .') + 'InternalError: Object can capture only once. Please file bug.') self._captured = True self._object = o @@ -2395,7 +2694,7 @@ class _CapturedObject(object): if not self._captured: raise RuntimeError( 'InternalError: Object is not captured properly before `get`. 
' - 'Please file bug .') + 'Please file bug.') return self._object @@ -2496,7 +2795,8 @@ class _Inputs(object): """ iterator = self._dataset.make_initializable_iterator() # pylint: disable=protected-access - hook = estimator_lib._DatasetInitializerHook(iterator) + hook = estimator_util._DatasetInitializerHook(iterator) + # pylint: enable=protected-access self._iterator = iterator return hook @@ -2642,6 +2942,7 @@ class _StopSignals(object): @staticmethod def should_stop(scalar_stopping_signal): + """Detects whether scalar_stopping_signal indicates stopping.""" if isinstance(scalar_stopping_signal, ops.Tensor): # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the TF # way to express the bool check whether scalar_stopping_signal is True. @@ -2792,3 +3093,16 @@ def _verify_cross_hosts_transfer_size(tensor_dict, message): '{}'.format(message, '\n'.join([ ' -- Key: {}, Shape: {}'.format(k, v) for k, v in tensor_structure.items()]))) + + +def _add_item_to_params(params, key, value): + """Adds a new item into `params`.""" + if isinstance(params, hparam.HParams): + # For HParams, we need to use special API. + if key in params: + params.key = value + else: + params.add_hparam(key, value) + else: + # Now params is Python dict. 
+ params[key] = value diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_test.py index c3882b8a27bc835f906c47dc5219f280c53800b8..6bdaa528f9f946ae4b9813d554409da2406b1f8d 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_test.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.tpu.python.tpu import training_loop from tensorflow.python.framework import dtypes from tensorflow.python.layers import convolutional from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops @@ -37,7 +38,8 @@ class TPUContextTest(test.TestCase): def testIsInContext(self): """Test that control_flow_util can check that we're in a TPU context.""" z1 = array_ops.identity(1) - context = tpu.TPUReplicateContext(b"context", 1) + pivot = control_flow_ops.no_op() + context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot) context.Enter() z2 = array_ops.identity(1) context.Exit() diff --git a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py index f305197c190b67355338c407a7895a0507941ddb..df07ff44ee68230cd06723d87c2f60407120e8dc 100644 --- a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py +++ b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py @@ -27,7 +27,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops @@ -506,19 +505,6 @@ class BatchSequencesWithStatesTest(test.TestCase): 
 expected_seq4_batch2=expected_seq4_batch2) -class BatchSequencesWithStatesTestWithCApi(BatchSequencesWithStatesTest): - - def setUp(self): - self._prev_value = ops._USE_C_API - ops._USE_C_API = True - super(BatchSequencesWithStatesTestWithCApi, self).setUp() - - def tearDown(self): - super(BatchSequencesWithStatesTestWithCApi, self).tearDown() - ops._USE_C_API = self._prev_value - - -@test_util.with_c_api class PaddingTest(test.TestCase): def testPaddingInvalidLengths(self): diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index f0418f04ba2c5c12c882d0b678f182058f25a94f..3beb7bfe3048a8f0294f7e9149b5a07b5fcc7d17 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -34,7 +34,7 @@ from tensorflow.python.util import deprecation # where is either a single token or [] enclosed list of tokens. # For example: "var[1] = a" or "x = [1,2,3]" PARAM_RE = re.compile(r""" - (?P<name>[a-zA-Z][\w]*) # variable name: "var" or "x" + (?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x" (\[\s*(?P<index>\d+)\s*\])? # (optional) index: "1" or None \s*=\s* ((?P<val>[^,\[]*) # single value: "a" or None @@ -200,6 +200,13 @@ def parse_values(values, type_map): If a hyperparameter name in both an index assignment and scalar assignment, a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1'). + The hyperparameter name may contain '.' symbols, which will result in an + attribute name that is only accessible through the getattr and setattr + functions. (And must be first explicit added through add_hparam.) + + WARNING: Use of '.' in your variable names is allowed, but is not well + supported and not recommended.
+ The `value` in `name=value` must follows the syntax according to the type of the parameter: diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py index 11fd15b5275a3c00b85bf986b2ff1ba0e2638aed..660c97f25e8458c345c8914bcaf98f37d047e50e 100644 --- a/tensorflow/contrib/training/python/training/hparam_test.py +++ b/tensorflow/contrib/training/python/training/hparam_test.py @@ -118,6 +118,21 @@ class HParamsTest(test.TestCase): self.assertEqual('2.3"', hparams2.c_c) self.assertEqual('/a=b/c/d', hparams2.d) + def testWithPeriodInVariableName(self): + hparams = hparam.HParams() + hparams.add_hparam(name='a.b', value=0.0) + hparams.parse('a.b=1.0') + self.assertEqual(1.0, getattr(hparams, 'a.b')) + hparams.add_hparam(name='c.d', value=0.0) + with self.assertRaisesRegexp(ValueError, 'Could not parse'): + hparams.parse('c.d=abc') + hparams.add_hparam(name='e.f', value='') + hparams.parse('e.f=abc') + self.assertEqual('abc', getattr(hparams, 'e.f')) + hparams.add_hparam(name='d..', value=0.0) + hparams.parse('d..=10.0') + self.assertEqual(10.0, getattr(hparams, 'd..')) + def testSetFromMap(self): hparams = hparam.HParams(a=1, b=2.0, c='tanh') hparams.override_from_dict({'a': -2, 'c': 'identity'}) diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py index 409aba817c1ec37003eb98f000f6cf8918234c5d..a2444934bc21d58ed57d15494b3548a31ce3a2df 100644 --- a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py +++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework 
import dtypes @@ -45,14 +46,14 @@ class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset): self._input_dataset = input_dataset self._batch_size = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") - # pylint: disable=protected-access if padded_shapes is None: self._padded_shapes = nest.map_structure( - dataset_ops._partial_shape_to_tensor, input_dataset.output_shapes) + convert.partial_shape_to_tensor, input_dataset.output_shapes) else: self._padded_shapes = nest.map_structure_up_to( - input_dataset.output_shapes, dataset_ops._partial_shape_to_tensor, + input_dataset.output_shapes, convert.partial_shape_to_tensor, padded_shapes) + # pylint: disable=protected-access padding_values = ( padding_values if padding_values is not None else dataset_ops._default_padding(input_dataset)) diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD index 9720fd6e8657de18cf8d7565f834568ae52fdbda..1b45584dcb84fe62de6cc14017e7cae575f99b2f 100644 --- a/tensorflow/contrib/verbs/BUILD +++ b/tensorflow/contrib/verbs/BUILD @@ -58,7 +58,7 @@ cc_library( "//tensorflow/core/distributed_runtime/rpc:async_service_interface", "//tensorflow/core/distributed_runtime/rpc:grpc_call", "//tensorflow/core/distributed_runtime/rpc:grpc_util", - "@grpc//:grpc++_unsecure", + "@grpc//:grpc++", ], alwayslink = 1, ) @@ -69,7 +69,7 @@ cc_library( hdrs = ["grpc_verbs_service_impl.h"], deps = [ ":verbs_service_proto_cc", - "@grpc//:grpc++_unsecure", + "@grpc//:grpc++", ], ) diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.cc b/tensorflow/contrib/verbs/grpc_verbs_service.cc index 742f946c9536973eb8a6a11afda1b32ae4a7726b..af29abd91feda22824e57c19c13a3f48fb1d61b7 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service.cc +++ b/tensorflow/contrib/verbs/grpc_verbs_service.cc @@ -15,9 +15,9 @@ limitations under the License. 
#ifdef TENSORFLOW_USE_VERBS -#include "grpc++/alarm.h" -#include "grpc++/grpc++.h" -#include "grpc++/server_builder.h" +#include "grpcpp/alarm.h" +#include "grpcpp/grpcpp.h" +#include "grpcpp/server_builder.h" #include "tensorflow/contrib/verbs/grpc_verbs_service.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc index 991f9a9d8bdf883b1b68bfa1fb6af7bf51b7e66a..4da7b59c69c88a4d04be37543aae7f03decd2c52 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc +++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc @@ -15,14 +15,14 @@ limitations under the License. #include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h" -#include "grpc++/impl/codegen/async_stream.h" -#include "grpc++/impl/codegen/async_unary_call.h" -#include "grpc++/impl/codegen/channel_interface.h" -#include "grpc++/impl/codegen/client_unary_call.h" -#include "grpc++/impl/codegen/method_handler_impl.h" -#include "grpc++/impl/codegen/rpc_service_method.h" -#include "grpc++/impl/codegen/service_type.h" -#include "grpc++/impl/codegen/sync_stream.h" +#include "grpcpp/impl/codegen/async_stream.h" +#include "grpcpp/impl/codegen/async_unary_call.h" +#include "grpcpp/impl/codegen/channel_interface.h" +#include "grpcpp/impl/codegen/client_unary_call.h" +#include "grpcpp/impl/codegen/method_handler_impl.h" +#include "grpcpp/impl/codegen/rpc_service_method.h" +#include "grpcpp/impl/codegen/service_type.h" +#include "grpcpp/impl/codegen/sync_stream.h" namespace tensorflow { diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h index 1f0f10517e98a32ae882c027330091928f1a6ee2..abe5e08b07cd71b7ca28321e6eb2cf0eec5d1b0f 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h +++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h @@ -16,14 +16,14 @@ limitations under the License. 
#ifndef TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_ #define TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_ -#include "grpc++/impl/codegen/async_stream.h" -#include "grpc++/impl/codegen/async_unary_call.h" -#include "grpc++/impl/codegen/proto_utils.h" -#include "grpc++/impl/codegen/rpc_method.h" -#include "grpc++/impl/codegen/service_type.h" -#include "grpc++/impl/codegen/status.h" -#include "grpc++/impl/codegen/stub_options.h" -#include "grpc++/impl/codegen/sync_stream.h" +#include "grpcpp/impl/codegen/async_stream.h" +#include "grpcpp/impl/codegen/async_unary_call.h" +#include "grpcpp/impl/codegen/proto_utils.h" +#include "grpcpp/impl/codegen/rpc_method.h" +#include "grpcpp/impl/codegen/service_type.h" +#include "grpcpp/impl/codegen/status.h" +#include "grpcpp/impl/codegen/stub_options.h" +#include "grpcpp/impl/codegen/sync_stream.h" #include "tensorflow/contrib/verbs/verbs_service.pb.h" diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 277f27f2688812653112eb38b4643acafbe6d414..8beabbc84d2f82492d1f4ce91e8cd238a592b33c 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -72,77 +72,76 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", + "cc_header_only_library", "full_path", "if_android", - "if_not_android_mips_and_mips64", "if_ios", "if_linux_x86_64", "if_mobile", "if_not_mobile", - "if_windows", "if_not_windows", - "tf_copts", + "if_windows", "tf_cc_test", "tf_cc_tests", + "tf_copts", "tf_cuda_library", "tf_gen_op_libs", "tf_generate_proto_text_sources", "tf_genrule_cmd_append_to_srcs", "tf_opts_nortti_if_android", - "cc_header_only_library", ) load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl") load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu") +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule") load("//tensorflow:tensorflow.bzl", "tf_cuda_only_cc_test") # For platform specific build 
config load( "//tensorflow/core:platform/default/build_config.bzl", - "tf_platform_hdrs", - "tf_platform_srcs", - "tf_proto_library", - "tf_proto_library_cc", "tf_additional_all_protos", + "tf_additional_cloud_kernel_deps", + "tf_additional_cloud_op_deps", "tf_additional_core_deps", + "tf_additional_cupti_wrapper_deps", + "tf_additional_device_tracer_cuda_deps", + "tf_additional_device_tracer_deps", + "tf_additional_device_tracer_srcs", + "tf_additional_gdr_lib_defines", + "tf_additional_human_readable_json_deps", "tf_additional_lib_defines", "tf_additional_lib_deps", "tf_additional_lib_hdrs", "tf_additional_lib_srcs", - "tf_additional_framework_hdrs", - "tf_additional_framework_srcs", - "tf_additional_minimal_lib_srcs", - "tf_additional_proto_hdrs", - "tf_additional_proto_srcs", - "tf_additional_cupti_wrapper_deps", "tf_additional_libdevice_data", "tf_additional_libdevice_deps", "tf_additional_libdevice_srcs", + "tf_additional_minimal_lib_srcs", + "tf_additional_mpi_lib_defines", + "tf_additional_proto_hdrs", + "tf_additional_proto_srcs", "tf_additional_test_deps", "tf_additional_test_srcs", - "tf_kernel_tests_linkstatic", - "tf_additional_cloud_op_deps", - "tf_additional_cloud_kernel_deps", - "tf_lib_proto_parsing_deps", "tf_additional_verbs_lib_defines", - "tf_additional_mpi_lib_defines", - "tf_additional_gdr_lib_defines", - "tf_additional_device_tracer_srcs", - "tf_additional_device_tracer_deps", - "tf_additional_device_tracer_cuda_deps", - "tf_pyclif_proto_library", "tf_jspb_proto_library", + "tf_kernel_tests_linkstatic", + "tf_lib_proto_parsing_deps", "tf_nano_proto_library", + "tf_platform_hdrs", + "tf_platform_srcs", + "tf_proto_library", + "tf_proto_library_cc", "tf_protos_all", "tf_protos_all_impl", "tf_protos_grappler", "tf_protos_grappler_impl", + "tf_pyclif_proto_library", ) load( "//tensorflow/core:platform/default/build_config_root.bzl", - "tf_cuda_tests_tags", "if_static", + "tf_cuda_tests_tags", ) load("@local_config_cuda//cuda:build_defs.bzl", 
"if_cuda") load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library") @@ -225,6 +224,7 @@ ADDITIONAL_CORE_PROTO_SRCS = [ "protobuf/named_tensor.proto", "protobuf/saved_model.proto", "protobuf/tensorflow_server.proto", + "protobuf/transport_options.proto", "util/test_log.proto", ] @@ -232,6 +232,7 @@ tf_proto_library( name = "protos_all", srcs = [], cc_api_version = 2, + dart_api_version = 2, default_header = True, j2objc_api_version = 1, java_api_version = 2, @@ -267,6 +268,12 @@ proto_library( visibility = ["//visibility:public"], ) +closure_proto_library( + name = "example_protos_closure", + visibility = ["//visibility:public"], + deps = [":example_protos"], +) + exports_files([ "framework/types.proto", ]) @@ -287,42 +294,18 @@ cc_library( ], ) -PLATFORM_BASE_HDRS = [ - "platform/env_time.h", - "platform/logging.h", - "platform/macros.h", - "platform/types.h", - "platform/byte_order.h", -] - -PLATFORM_OTHER_HDRS = [ - "platform/abi.h", - "platform/stacktrace.h", - "platform/stacktrace_handler.h", - "platform/context.h", - "platform/cpu_info.h", - "platform/cpu_feature_guard.h", - "platform/dynamic_annotations.h", - "platform/env.h", - "platform/file_system.h", - "platform/file_system_helper.h", - "platform/fingerprint.h", - "platform/init_main.h", - "platform/mem.h", - "platform/mutex.h", - "platform/net.h", - "platform/notification.h", - "platform/null_file_system.h", - "platform/prefetch.h", - "platform/profile_utils/clock_cycle_profiler.h", - "platform/profile_utils/cpu_utils.h", - "platform/protobuf.h", - "platform/strong_hash.h", - "platform/subprocess.h", - "platform/thread_annotations.h", -] +filegroup( + name = "platform_base_hdrs", + srcs = [ + "platform/byte_order.h", + "platform/env_time.h", + "platform/logging.h", + "platform/macros.h", + "platform/types.h", + ], + visibility = ["//visibility:private"], +) -# Smaller platform libraries that don't depend on "lib" or "lib_internal". 
cc_library( name = "platform_base", srcs = tf_platform_hdrs([ @@ -334,14 +317,275 @@ cc_library( ]) + [ "platform/env_time.cc", ], - hdrs = PLATFORM_BASE_HDRS, + hdrs = [":platform_base_hdrs"], copts = tf_copts(), + tags = ["avoid_dep"], + visibility = ["//tensorflow/core:__subpackages__"], deps = [ ":lib_platform", "//tensorflow/core/platform/default/build_config:base", ], ) +filegroup( + name = "platform_port_hdrs", + srcs = [ + "platform/cpu_info.h", + "platform/dynamic_annotations.h", + "platform/init_main.h", + "platform/mem.h", + "platform/mutex.h", + "platform/thread_annotations.h", + ], + visibility = ["//visibility:private"], +) + +# Headers that are not exported as part of ":lib". +filegroup( + name = "platform_port_internal_hdrs", + srcs = [ + "platform/demangle.h", + "platform/host_info.h", + "platform/snappy.h", + ], + visibility = ["//visibility:private"], +) + +cc_library( + name = "platform_port", + srcs = tf_platform_hdrs([ + "cpu_info.h", + "dynamic_annotations.h", + "thread_annotations.h", + "mutex.h", + ]) + tf_platform_srcs([ + "port.cc", + ]) + [ + "platform/cpu_info.cc", + ], + hdrs = [ + ":platform_port_hdrs", + ":platform_port_internal_hdrs", + ], + copts = tf_copts(), + visibility = ["//tensorflow/core:__subpackages__"], + deps = [ + ":lib_platform", + ":platform_base", + "//tensorflow/core/platform/default/build_config:port", + "@snappy", + ], +) + +filegroup( + name = "platform_protobuf_hdrs", + srcs = [ + "platform/protobuf.h", + ], + visibility = ["//visibility:private"], +) + +# Headers that are not exported as part of ":lib". 
+filegroup( + name = "platform_protobuf_internal_hdrs", + srcs = [ + "platform/protobuf_internal.h", + ], + visibility = ["//visibility:private"], +) + +cc_library( + name = "platform_protobuf", + srcs = tf_platform_hdrs([ + "protobuf.h", + ]) + tf_platform_srcs([ + "protobuf.cc", + ]) + [ + "platform/protobuf_util.cc", + "lib/core/status.h", + ], + hdrs = [ + ":platform_protobuf_hdrs", + ":platform_protobuf_internal_hdrs", + ], + copts = tf_copts(), + visibility = ["//tensorflow/core:__subpackages__"], + deps = [ + ":lib_platform", + ":platform_base", + ":platform_port", + "//tensorflow/core/platform/default/build_config:protobuf", + "@protobuf_archive//:protobuf", + ], +) + +cc_library( + name = "human_readable_json", + srcs = tf_platform_srcs(["human_readable_json.cc"]), + hdrs = ["platform/human_readable_json.h"], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":lib", + ":lib_internal", + ] + tf_additional_human_readable_json_deps(), +) + +filegroup( + name = "platform_env_hdrs", + srcs = [ + "platform/env.h", + "platform/file_statistics.h", + "platform/file_system.h", + ], + visibility = ["//visibility:private"], +) + +# Headers that are not exported as part of ":lib". 
+filegroup( + name = "platform_env_internal_hdrs", + srcs = [ + "platform/load_library.h", + ], + visibility = ["//visibility:private"], +) + +cc_library( + name = "platform_env", + srcs = tf_platform_srcs([ + "env.cc", + "load_library.cc", + ]) + tf_platform_hdrs([ + "wide_char.h", + ]) + [ + "platform/env.cc", + "platform/file_system.cc", + ], + hdrs = [ + ":platform_env_hdrs", + ":platform_env_internal_hdrs", + ], + copts = tf_copts(), + visibility = ["//tensorflow/core:__subpackages__"], + deps = [ + ":error_codes_proto_cc", + ":lib", + ":lib_internal", + ":lib_platform", + ":platform_base", + ":platform_port", + ":platform_protobuf", + "//tensorflow/core/platform/default/build_config:env", + ], +) + +filegroup( + name = "platform_file_system_hdrs", + srcs = [ + "platform/file_system_helper.h", + "platform/null_file_system.h", + ], + visibility = ["//visibility:private"], +) + +cc_library( + name = "platform_file_system", + srcs = tf_platform_srcs([ + ]) + tf_platform_hdrs([ + "windows_file_system.h", + ]) + [ + "platform/file_system_helper.cc", + ], + hdrs = [ + ":platform_file_system_hdrs", + ], + copts = tf_copts(), + visibility = ["//tensorflow/core:__subpackages__"], + deps = [ + ":lib", + ":lib_platform", + ":platform_env", + ], +) + +filegroup( + name = "platform_other_hdrs", + srcs = [ + "platform/abi.h", + "platform/context.h", + "platform/cpu_feature_guard.h", + "platform/error.h", + "platform/fingerprint.h", + "platform/net.h", + "platform/notification.h", + "platform/prefetch.h", + "platform/profile_utils/android_armv7a_cpu_utils_helper.h", + "platform/profile_utils/clock_cycle_profiler.h", + "platform/profile_utils/cpu_utils.h", + "platform/profile_utils/i_cpu_utils_helper.h", + "platform/stacktrace.h", + "platform/stacktrace_handler.h", + "platform/strong_hash.h", + "platform/subprocess.h", + ], + visibility = ["//visibility:private"], +) + +# Headers that are not exported as part of ":lib". 
+filegroup( + name = "platform_other_internal_hdrs", + srcs = [ + "platform/denormal.h", + "platform/setround.h", + "platform/tracing.h", + ], + visibility = ["//visibility:private"], +) + +cc_library( + name = "platform_other", + srcs = tf_platform_srcs([ + "subprocess.cc", + "net.cc", + "tracing.cc", + ]) + tf_platform_hdrs([ + "tracing.h", + "error.h", + "context.h", + "fingerprint.h", + "notification.h", + "stacktrace.h", + "strong_hash.h", + "subprocess.h", + "tracing_impl.h", + ]) + [ + "platform/cpu_feature_guard.cc", + "platform/setround.cc", + "platform/tracing.cc", + "platform/denormal.cc", + "platform/profile_utils/android_armv7a_cpu_utils_helper.cc", + "platform/profile_utils/clock_cycle_profiler.cc", + "platform/profile_utils/cpu_utils.cc", + ], + hdrs = [ + ":platform_other_hdrs", + ":platform_other_internal_hdrs", + ], + copts = tf_copts(), + visibility = ["//tensorflow/core:__subpackages__"], + deps = [ + ":lib", + ":lib_platform", + ":platform_base", + ":platform_env", + ":platform_port", + ":platform_protobuf", + "//tensorflow/core/platform/default/build_config:other", + "//tensorflow/core/platform/default/build_config:platformlib", + "//tensorflow/core/platform/default/build_config:port", + ], +) + # Minimal lib so that tools used for mobile compilation # don't have to depend on lib/platformlib. cc_library( @@ -375,8 +619,7 @@ cc_library( # tf_cc_test and tf_cc_binary will include the necessary symbols. 
cc_library( name = "lib", - hdrs = PLATFORM_BASE_HDRS + - PLATFORM_OTHER_HDRS + [ + hdrs = [ "lib/bfloat16/bfloat16.h", "lib/core/arena.h", "lib/core/bitmap.h", @@ -423,6 +666,12 @@ cc_library( "lib/strings/str_util.h", "lib/strings/strcat.h", "lib/strings/stringprintf.h", + ":platform_base_hdrs", + ":platform_env_hdrs", + ":platform_file_system_hdrs", + ":platform_other_hdrs", + ":platform_port_hdrs", + ":platform_protobuf_hdrs", ], visibility = ["//visibility:public"], deps = [ @@ -578,6 +827,7 @@ tf_cuda_library( "framework/types.h", "public/version.h", "util/activation_mode.h", + "util/batch_util.h", "util/bcast.h", "util/cuda_kernel_helper.h", "util/device_name_utils.h", @@ -595,6 +845,7 @@ tf_cuda_library( "util/sparse/group_iterator.h", "util/sparse/sparse_tensor.h", "util/stat_summarizer.h", + "util/stat_summarizer_options.h", "util/stream_executor_util.h", "util/strided_slice_op.h", "util/tensor_format.h", @@ -619,6 +870,18 @@ tf_cuda_library( deps = [":framework_internal"], ) +cc_library( + name = "stats_calculator_portable", + srcs = [ + "util/stat_summarizer_options.h", + "util/stats_calculator.cc", + ], + hdrs = [ + "util/stats_calculator.h", + ], + copts = tf_copts(), +) + cc_library( name = "overflow", hdrs = ["util/overflow.h"], @@ -628,6 +891,12 @@ cc_library( ], ) +cc_library( + name = "exec_on_stall", + hdrs = ["util/exec_on_stall.h"], + deps = [":framework_lite"], +) + cc_library( name = "ptr_util", hdrs = ["util/ptr_util.h"], @@ -801,8 +1070,6 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor", "//tensorflow/core/kernels:bounds_check_lib", - "//third_party/eigen3", - "@farmhash_archive//:farmhash", ], alwayslink = 1, ) @@ -1101,6 +1368,7 @@ cc_library( ":shape_inference_testutil", ":tensor_testutil", ":test", + ":testlib_ops", "//tensorflow/cc:scope", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:ops_testutil", @@ -1108,6 +1376,18 @@ cc_library( ], ) +cc_library( + name = "testlib_ops", + 
testonly = 1, + srcs = ["common_runtime/testlib_ops.cc"], + linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + # This is a link-only library to provide a DirectSession # implementation of the Session interface. tf_cuda_library( @@ -1169,6 +1449,7 @@ filegroup( "lib/png/**/*", "lib/gif/**/*", "util/events_writer.*", + "util/stats_calculator.*", "util/reporter.*", "platform/**/cuda_libdevice_path.*", "platform/default/test_benchmark.*", @@ -1252,6 +1533,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":protos_all_cc_impl", + ":stats_calculator_portable", "//third_party/eigen3", "@double_conversion//:double-conversion", "@nsync//:nsync_cpp", @@ -1292,6 +1574,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":protos_all_cc_impl", + ":stats_calculator_portable", "//third_party/eigen3", "@double_conversion//:double-conversion", "@nsync//:nsync_cpp", @@ -1518,6 +1801,13 @@ tf_pyclif_proto_library( visibility = ["//visibility:public"], ) +tf_pyclif_proto_library( + name = "framework/cost_graph_pyclif", + proto_lib = ":protos_all_cc", + proto_srcfile = "framework/cost_graph.proto", + visibility = ["//visibility:public"], +) + tf_pyclif_proto_library( name = "framework/tensor_pyclif", proto_lib = ":protos_all_cc", @@ -1655,13 +1945,6 @@ LIB_INTERNAL_PRIVATE_HEADERS = ["framework/resource_handle.h"] + glob( "platform/**/cuda.h", "platform/**/stream_executor.h", ], -) + tf_additional_lib_srcs( - exclude = [ - "**/*.cc", - "**/*test*", - "platform/**/cuda.h", - "platform/**/stream_executor.h", - ], ) LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [ @@ -1757,9 +2040,8 @@ cc_library( "platform/**/cuda_libdevice_path.cc", "platform/**/device_tracer.cc", "platform/**/logging.cc", + "platform/**/human_readable_json.cc", "platform/abi.cc", - "platform/variant_coding.cc", - "platform/**/variant_cord_coding.cc", ], ) + 
tf_additional_lib_srcs(
             exclude = [
@@ -1771,9 +2053,8 @@ cc_library(
                 "platform/**/env_time.cc",
                 "platform/**/device_tracer.cc",
                 "platform/**/logging.cc",
+                "platform/**/human_readable_json.cc",
                 "platform/abi.cc",
-                "platform/variant_coding.cc",
-                "platform/**/variant_cord_coding.cc",
             ] +
             # Protobuf deps already included through the ":lib_proto_parsing"
             # dependency.
@@ -1957,10 +2238,12 @@ tf_proto_library(
     name = "error_codes_proto",
     srcs = ERROR_CODES_PROTO_SRCS,
     cc_api_version = 2,
+    dart_api_version = 2,
     default_header = True,
     j2objc_api_version = 1,
     java_api_version = 2,
     js_api_version = 2,
+    provide_cc_alias = True,
 )
 
 tf_generate_proto_text_sources(
@@ -1978,6 +2261,7 @@ tf_proto_library(
     name = "protos_all_proto",
     srcs = COMMON_PROTO_SRCS + ADDITIONAL_CORE_PROTO_SRCS,
     cc_api_version = 2,
+    dart_api_version = 2,
     default_header = True,
     j2objc_api_version = 1,
     java_api_version = 2,
@@ -2022,7 +2306,6 @@ cc_library(
 )
 
 FRAMEWORK_INTERNAL_PRIVATE_HEADERS = [
-    "platform/variant_coding.h",
     "graph/edgeset.h",
     "graph/graph.h",
     "graph/graph_def_builder.h",
@@ -2063,14 +2346,13 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [
     "framework/tracking_allocator.h",  # only needed for tests
     "framework/unique_tensor_references.h",
     "framework/variant.h",
-    "platform/variant_coding.h",
     "util/command_line_flags.h",
     "util/env_var.h",
     "util/equal_graph_def.h",
     "util/presized_cuckoo_map.h",
     "util/tensor_slice_set.h",
     "util/tensor_slice_util.h",
-] + tf_additional_framework_hdrs()
+]
 
 tf_cuda_library(
     name = "framework_internal",
@@ -2112,9 +2394,7 @@ cc_header_only_library(
 
 tf_cuda_library(
     name = "framework_internal_impl",
-    srcs = FRAMEWORK_INTERNAL_PRIVATE_HEADERS + [
-        "platform/variant_coding.cc",
-    ] + glob(
+    srcs = FRAMEWORK_INTERNAL_PRIVATE_HEADERS + glob(
         [
             "example/**/*.cc",
             "framework/**/*.cc",
@@ -2148,7 +2428,7 @@ tf_cuda_library(
             "util/memmapped_file_system.cc",
             "util/memmapped_file_system_writer.cc",
         ],
-    }) + tf_additional_framework_srcs(),
+    }),
     hdrs = FRAMEWORK_INTERNAL_PUBLIC_HEADERS,
     copts = tf_copts(),
     linkopts = select({
@@ -2357,8 +2637,10 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
     "common_runtime/dma_helper.h",
     "common_runtime/eigen_thread_pool.h",
     "common_runtime/executor.h",
+    "common_runtime/executor_factory.h",
     "common_runtime/graph_optimizer.h",
     "common_runtime/local_device.h",
+    "common_runtime/lower_if_op.h",
     "common_runtime/memory_types.h",
     "common_runtime/mkl_cpu_allocator.h",
     "common_runtime/optimization_registry.h",
@@ -2405,10 +2687,12 @@ tf_cuda_library(
         "common_runtime/device_resolver_local.cc",
         "common_runtime/device_set.cc",
         "common_runtime/executor.cc",
+        "common_runtime/executor_factory.cc",
         "common_runtime/function.cc",
         "common_runtime/graph_optimizer.cc",
         "common_runtime/graph_runner.cc",
         "common_runtime/local_device.cc",
+        "common_runtime/lower_if_op.cc",
         "common_runtime/memory_types.cc",
         "common_runtime/mkl_cpu_allocator.cc",
         "common_runtime/optimization_registry.cc",
@@ -2511,6 +2795,7 @@ cc_library(
     ],
     visibility = [
         "//tensorflow/compiler:__subpackages__",
+        "//tensorflow/core/kernels:__subpackages__",
         "//tensorflow/core/profiler:__subpackages__",
     ],
     deps = [":lib_internal"],
@@ -2566,6 +2851,7 @@ tf_cuda_library(
     ],
     copts = tf_copts(),
     cuda_deps = tf_additional_cupti_wrapper_deps() + tf_additional_device_tracer_cuda_deps(),
+    visibility = ["//visibility:private"],
     deps = [
         ":core_cpu_internal",
         ":lib",
@@ -2983,6 +3269,18 @@ tf_cc_test(
     ],
 )
 
+tf_cc_test(
+    name = "exec_on_stall_test",
+    size = "small",
+    srcs = ["util/exec_on_stall_test.cc"],
+    deps = [
+        ":exec_on_stall",
+        ":framework_lite",
+        ":test",
+        ":test_main",
+    ],
+)
+
 tf_cc_test(
     name = "lib_jpeg_jpeg_mem_unittest",
     srcs = ["lib/jpeg/jpeg_mem_unittest.cc"],
@@ -3277,7 +3575,10 @@ tf_cc_tests_gpu(
 tf_cc_test_mkl(
     name = "mkl_runtime_tests",
     size = "small",
-    srcs = ["common_runtime/mkl_cpu_allocator_test.cc"],
+    srcs = [
+        "common_runtime/mkl_cpu_allocator_test.cc",
+        "common_runtime/mkl_threadpool_device_test.cc",
+    ],
     linkstatic = 1,
     deps = [
         ":core",
@@ -3379,6 +3680,37 @@ tf_cc_tests_gpu(
     ],
 )
 
+tf_cuda_cc_test(
+    name = "gpu_device_unified_memory_test",
+    size = "small",
+    srcs = [
+        "common_runtime/gpu/gpu_device_test.cc",
+    ],
+    linkstatic = tf_kernel_tests_linkstatic(),
+    # Runs test on a Guitar cluster that uses P100s to test unified memory
+    # allocations.
+    tags = tf_cuda_tests_tags() + [
+        "guitar",
+        "multi_gpu",
+    ],
+    deps = [
+        ":core_cpu",
+        ":core_cpu_internal",
+        ":direct_session",
+        ":framework",
+        ":framework_internal",
+        ":gpu_id",
+        ":lib",
+        ":lib_internal",
+        ":protos_all_cc",
+        ":test",
+        ":test_main",
+        ":testlib",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/core/kernels:ops_util",
+    ],
+)
+
 tf_cc_test_gpu(
     name = "cuda_libdevice_path_test",
     size = "small",
@@ -3712,6 +4044,31 @@ tf_cc_test(
     ],
 )
 
+tf_cc_test(
+    name = "common_runtime_executor_test",
+    size = "small",
+    srcs = ["common_runtime/executor_test.cc"],
+    linkstatic = tf_kernel_tests_linkstatic(),
+    deps = [
+        ":core",
+        ":core_cpu",
+        ":core_cpu_internal",
+        ":framework",
+        ":framework_internal",
+        ":lib",
+        ":lib_internal",
+        ":protos_all_cc",
+        ":test",
+        ":test_main",
+        ":testlib",
+        "//tensorflow/core/kernels:array",
+        "//tensorflow/core/kernels:control_flow_ops",
+        "//tensorflow/core/kernels:math",
+        "//tensorflow/core/kernels:random_ops",
+        "//tensorflow/core/kernels:state",
+    ],
+)
+
 tf_cc_test(
     name = "common_runtime_function_test",
     size = "small",
@@ -4072,6 +4429,29 @@ tf_cc_test_gpu(
     ],
 )
 
+tf_cc_tests(
+    name = "common_runtime_lower_if_op_test",
+    size = "small",
+    srcs = ["common_runtime/lower_if_op_test.cc"],
+    deps = [
+        ":all_kernels",
+        ":core_cpu",
+        ":core_cpu_internal",
+        ":direct_session",
+        ":framework",
+        ":framework_internal",
+        ":lib",
+        ":test",
+        ":test_main",
+        ":testlib",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:cc_ops_internal",
+        "//tensorflow/cc:client_session",
+        "//tensorflow/cc:function_ops",
+        "//tensorflow/cc:ops",
+    ],
+)
+
 # Test data
 filegroup(
     name = "image_testdata",
@@ -4165,9 +4545,3 @@ alias(
     actual = ":mobile_srcs",
     visibility = ["//visibility:public"],
 )
-
-closure_proto_library(
-    name = "example_protos_closure",
-    visibility = ["//visibility:public"],
-    deps = [":example_protos"],
-)
diff --git a/tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt b/tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..d8c2ed40a324d4854d83c471e8eef50e50277b93
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt
@@ -0,0 +1,13 @@
+op {
+  graph_op_name: "AnonymousIterator"
+  out_arg {
+    name: "handle"
+    description: <